読者です 読者をやめる 読者になる 読者になる

凹みTips

C++、JavaScript、Unity、ガジェット等の Tips について雑多に書いています。

C++による線形補間

数値計算 C++

MathcadをC++に翻訳しよう,という試みをしているので,積分やら挟み打ち法によるroot計算やらを実装したいのですが,その一貫として線形補間クラスを作成しました.
中身は至ってシンプルで,

  1. x, y配列(vectorなど)を与える
  2. それぞれの配列が2個以上データを持ってるか確認(直線が定義できないので)
  3. それぞれの要素数が一致しているか確認
  4. x要素が単調増加しているか確認

コンストラクタを動かした後,Get(x)で所望のyを得られるというものです.線形補間しているGetの中身は,

  1. 範囲の外側なら端の2点で線形補間
  2. 範囲の内側なら,どこのxに対応するかを探して前後の点で線形補間

といったものです.
コードは以下になります.
lerp.h

#include <iostream>

template <class T, template <class A, class Allocator = std::allocator<A> > class Container = std::vector>
class CLerp
{
private:
	bool Valid;
	Container<T> X, Y;
	const int N;

public:
	CLerp(Container<T> x, Container<T> y) : Valid(false), X(x), Y(y), N(x.size()-1)
	{
		// 要素数が2個以上か調べる
		if (X.size() < 2 || Y.size() < 2) {
			std::cout << "Error! The size of X or Y must be greater than 2." << std::endl;
			return;
		}

		// 要素数が同じか調べる
		if (X.size() != Y.size()) {
			std::cout << "Error! The size of X and Y are different." << std::endl;
			return;
		}

		// 単調増加か調べる
		for (int i=0; i<N-1; i++) {
			if (X[i] > X[i+1]) {
				std::cout << "Error! X must be monotonically increasing." << std::endl;
				return;
			}
		}

		Valid = true;
	}

	T Get(T x) {
		// コンストラクタが正常でない場合終了
		if (!Valid) {
			return 0;
		}

		// 最初の要素より小さかった場合,最初の2つの要素を線形補間
		if (x < X[0]) {
			return Y[0] + (Y[1] - Y[0])/(X[1] - X[0]) * (x - X[0]);
		}
		// 最後の要素より大きかった場合,最後の2つの要素を線形補間
		if (x > X[N]) {
			return Y[N] + (Y[N] - Y[N-1])/(X[N] - X[N-1]) * (x - X[N]);
		}
		// 範囲内の場合
		int cnt = 0, prev, next;
		bool flag = false;
		while (cnt < N) {
			if (x >= X[cnt] && x <= X[cnt+1]) {
				prev = cnt;
				next = cnt+1;
				flag = true;
				break;
			}
			cnt++;
		}
		return Y[prev] + (Y[next] - Y[prev])/(X[next] - X[prev]) * (x - X[prev]);
	}

};

馬鹿正直な作り方です.もっとスマートに作りたかった….でも動けばよし.

ついでに,等間隔なベクトルデータを生成したり,関数ポインタで値を代入する関数を用意しておきます.

calc.h

#include <iostream>
#include <vector>
#include <algorithm>

// 等間隔ベクトルをセットする関数
// first, first+(end-first)/div, ... , end?
template <class T, template <class A, class Allocator = std::allocator<A> > class Container>
void setVector(Container<T> &cont, const T first, const T end, const int div)
{
	if (div <= 0) {
		std::cout << "Error! div <= 0 (@setVector)" << std::endl;
	}
	for (int i=0; i<=div; i++) {
		cont.push_back(first + ((end - first)*(T)i)/(T)div);
	}
}

// 関数ポインタでベクトルをセットする関数
template <class T, template <class A, class Allocator = std::allocator<A> > class Container>
void setVector(Container<T> &cont, Container<T> &x, T (*setFunc)(T))
{
	cont.clear();
	std::transform(x.begin(), x.end(), std::back_inserter(cont), setFunc);
}

CLerpクラスのメンバ関数ポインタをsetFuncに投げて代入したり出来たら良いなぁと思っていて,以下のようなコードを書いてみたのですが,コンパイラが落ちてしまって上手く動きませんでした.

template <class T, template <class A, class Allocator = std::allocator<A> > class Container>
T (*getLerp(CLerp<T, Container> &lerp))(T)
{
	T (CLerp<T, Container>::*pGet)(T) = &CLerp<T, Container>::Get;

	return (lerp.*pGet);
}

仕方なく,calc.hに以下のようなコードを追加して泣く泣く妥協しました.

calc.h

〜 前略 〜
// ベクトルをセットする関数(Getを持つクラス(ex. CLerp)ver.)
template <class T, template <class A, class Allocator = std::allocator<A> > class Container, class Sub>
void setVector(Container<T> &cont, Container<T> &x, Sub &s)
{
	cont.clear();
	Container<T>::iterator it = x.begin();
	while (it != x.end()) {
		cont.push_back(s.Get(*it));
		it++;
	}
}

これを用いて以下のようなコードを実行してみます.なお,プロットには前回(d:id:hecomi:20100709)作ったCGnuplotを用いています.

#include <iostream>
#include <vector>
#include "calc.h"
#include "lerp.h"
#include "gnuplot.h"

using namespace std;

double hoge(double x)
{
	return 3*x*x*x*x - 2*x*x*x + 5*x*x - x;
}

int main()
{
	vector<double> x1, x2, sp1, sp2;
	// 適当にプロット用ベクトルを作成
	setVector(x1, -2.0, 2.0, 100);
	// 関数hogeに従ったグラフを作成
	setVector(sp1, x1, hoge);
	// 線形補間実行
	CLerp<double> lp(x1, sp1);
	// 範囲の外にまではみ出るプロット用ベクトルを作成
	setVector(x2, -5.0, 5.0, 100);
	// x2に応じて線形補間したhogeをプロット
	setVector(sp2, x2, lp);
	// 出力
	CGnuplot gp;
	gp.Plot(x2, sp2);

	int num;
	cin >> num;
}

得られた結果が次に成ります.

成功したことが分かります.
にしても,コンパイル通らないエラーは何故だろう….どなたか助けてください.

追記(10/07/11)

なんて実装をしておりましたが,よくよく考えたらSTL見習ってファンクタ使った実装にすれば全て解決ですよね.

T Get(T x) {...}

T operator()(T x) {...}

へと変更.calc.hを,

template <class T, template <class A, class Allocator = std::allocator<A> > class Container, typename Functor>
void setVector(Container<T> &cont, Container<T> &x, Functor &func)
{
	cont.clear();
	std::transform(x.begin(), x.end(), std::back_inserter(cont), func);
}

とすれば,CLerp版のsetVectorとか関数ポインタ用などと分けなくても良くなったので,同じmain()で動作しますね.

追記2(10/07/17)

boost::lambda::bindなどで使えるように、改良しました。

#include <iostream>
#include <functional>

template <class T, template <class A, class Allocator = std::allocator<A> > class Container = std::vector>
class CLerp : public std::unary_function<T, T> // bindで使えるようにするため
{
private:
	bool Valid;
	Container<T> X, Y;
	const int N;

public:
	CLerp(Container<T> x, Container<T> y) : Valid(false), X(x), Y(y), N(x.size()-1)
	{
		// 要素数が2個以上か調べる
		if (X.size() < 2 || Y.size() < 2) {
			std::cout << "Error! The size of X or Y must be greater than 2." << std::endl;
			return;
		}

		// 要素数が同じか調べる
		if (X.size() != Y.size()) {
			std::cout << "Error! The size of X and Y are different." << std::endl;
			return;
		}

		// 単調増加か調べる
		for (int i=0; i<N-1; i++) {
			if (X[i] > X[i+1]) {
				std::cout << "Error! X must be monotonically increasing." << std::endl;
				return;
			}
		}

		Valid = true;
	}

	T operator()(T x) const {
		// コンストラクタが正常でない場合終了
		if (!Valid) {
			return 0;
		}

		// 最初の要素より小さかった場合,最初の2つの要素を線形補間
		if (x < X[0]) {
			return Y[0] + (Y[1] - Y[0])/(X[1] - X[0]) * (x - X[0]);
		}
		// 最後の要素より大きかった場合,最後の2つの要素を線形補間
		if (x > X[N]) {
			return Y[N] + (Y[N] - Y[N-1])/(X[N] - X[N-1]) * (x - X[N]);
		}
		// 範囲内の場合
		int cnt = 0, prev, next;
		bool flag = false;
		while (cnt < N) {
			if (x >= X[cnt] && x <= X[cnt+1]) {
				prev = cnt;
				next = cnt+1;
				flag = true;
				break;
			}
			cnt++;
		}
		return Y[prev] + (Y[next] - Y[prev])/(X[next] - X[prev]) * (x - X[prev]);
	}

};