凹みTips

C++、JavaScript、Unity、ガジェット等の Tips について雑多に書いています。

OLL によるオンライン学習を試してみた

はじめに

音声家電制御に用いている音声認識エンジン Julius の認識結果の精度を高めるために、機械学習を取り入れようと思います。機械学習を Julius に取り込んだ例としては、ルールベースjuliusの誤認識対策にSVMを利用してみよう - お前の血は何色だ!! 4 にて、rtiさんが liblinear を用いたものを紹介してされています。しかしながら liblinear ではオンライン学習できないことを課題として挙げられていました。そこでうまくやってくれそうな C++ ライブラリを探したところ、oll - oll: Online-Learning Library - Google Project Hosting を見つけました。以下、Wiki(OllMainJa - oll - oll: Online-Learning Library - Google Project Hosting)に書いてあった説明文を抜粋:

OLL は様々なオンライン学習をサポートした機械学習ライブラリであり、特に自然言語処理など、大規模、かつ疎な学習問題に最適化されています。これらのオンライン学習手法は速度面、作業領域面で非常に効率的(学習サンプル数、素性種類数に比例)でありながら、SVMsやMEsなどのバッチ学習と同程度の精度を達成します。
学習、推定を行なうプログラムとC++ libraryを提供しています。

現在サポートしている学習手法は次の通りです。

  • Perceptron [F. Rosenblatt 1958]
  • Averaged Perceptron [M. Collins 2002]
  • Passive Agressive (PA, PA-I, PA-II) [K. Crammer, et. al. 2006]
  • ALMA (modified slightly from original) [H. Daume, 2007]
  • Confidence Weighted Linear-Classification [M. Dredze, 2008]

とのことです。扱いやすくするために、基本的な機能をまとめたクラスを書いてみました。

こんな風に使えるようになります

test.cpp
#include <iostream>
#include "oll_class.hpp"

int main(int argc, char const* argv[])
{
	{
		OLL<oll_tool::PA1> oll;
		oll.add(1,  "0:1.0 1:2.0 2:-1.0");
		oll.add(-1, "0:-0.5 1:1.0 2:-0.5");
		std::cout << oll.test("0:1.0 1:1.0") << std::endl;
		oll.save("test.oll");
	}

	{
		OLL<oll_tool::PA1> oll;
		oll.load("test.oll");
		std::cout << oll.test("0:1.0 1:1.0") << std::endl;
	}
	return 0;
}
コンパイル
$ g++ test.cpp oll.cpp -o test
出力
$ ./test
0.171429
0.171429

add で「ラベル:数値 ラベル:数値 ...」として文章をデータとして追加し、test で試したい同様なフォーマットの文章をデータとして与えると精度が出力されます。
save でファイルに学習結果を保存、load で取り出すことが出来ます。

コード

oll_line.cpp を参考に書いてみました。

oll_class.hpp
#include <string>
#include <iostream>
#include <cstdlib>
#include "oll.hpp"

template<int TrainMethodNum> struct train_method { typedef void type; };
template<> struct train_method<0> { typedef oll_tool::P_s   type; }; // Perceptron
template<> struct train_method<1> { typedef oll_tool::AP_s  type; }; // Averaged Perceptron
template<> struct train_method<2> { typedef oll_tool::PA_s  type; }; // Passive Agressive
template<> struct train_method<3> { typedef oll_tool::PA1_s type; }; // Passive Agressive L1
template<> struct train_method<4> { typedef oll_tool::PA2_s type; }; // Passive Agressive L2
template<> struct train_method<5> { typedef oll_tool::PAK_s type; }; // Kernelized Passive Agressive
template<> struct train_method<6> { typedef oll_tool::CW_s  type; }; // Confidence Weighted
template<> struct train_method<7> { typedef oll_tool::AL_s  type; }; // ALMA HD

/**
 *  オンライン学習ライブラリの機能をまとめたクラス
 *  @template TrainMethodNum oll_tool::学習手法(P, AP, PA, PA1, PA2, PAK, CW, AL)
 */
template <int TrainMethodNum = oll_tool::PA1>
class OLL
{
public:
	typedef typename train_method<TrainMethodNum>::type TrainMethod;

	/**
	 *  コンストラクタ
	 *  @param[in] C    Regularization Parameter
	 *  @param[in] bias Bias
	 */
	OLL(float C = 1.f, float bias = 0.f)
	: tm_( static_cast<oll_tool::trainMethod>(TrainMethodNum) )
	{
		ol_.setC(C);
		ol_.setBias(bias);
	}

	/**
	 *  学習結果をファイルに保存
	 *  @param[in] file_name 保存先ファイル名
	 */
	bool save(const std::string& file_name)
	{
		if ( ol_.save(file_name.c_str()) == -1) {
			std::cerr << ol_.getErrorLog() << std::endl;
			return false;
		}
		return true;
	}

	/**
	 *  学習結果をファイルから復元
	 *  @param[in] file_name 復元元ファイル名
	 */
	bool load(const std::string& file_name)
	{
		if ( ol_.load(file_name.c_str()) == -1) {
			std::cerr << ol_.getErrorLog() << std::endl;
			return false;
		}
		return true;
	}

	/**
	 *  データを渡して学習させる
	 *  @param[in] flag true: +のデータ、false: -のデータ
	 *  @param[in] data 学習データ (format: id:val id:val ...)
	 */
	bool add(int flag, const std::string& data)
	{
		std::string format = ( (flag > 0) ? "1 " : "-1 " ) + data;
		oll_tool::fv_t fv;
		int y = 0;

		if (ol_.parseLine(format, fv, y) == -1) {
			std::cerr << ol_.getErrorLog() << std::endl;
			return false;
		}

		TrainMethod a;
		ol_.trainExample(a, fv, y);

		return true;
	}

	/**
	 *  データをテストする
	 *  @param[in] data テストデータ : id:val id:val ...
	 */
	float test(const std::string& data)
	{
		std::string format = "0 " + data;
		oll_tool::fv_t fv;
		int y = 0;

		if (ol_.parseLine(format, fv, y) == -1) {
			std::cerr << ol_.getErrorLog() << std::endl;
			return -1.f;
		}

		return ol_.classify(fv);
	}

private:
	oll_tool::oll ol_;
	oll_tool::trainMethod tm_;
};

今後の展望

Julius の認識結果を TTS して、違う場合は「いや」とか「違う」と言ってあげることで誤認識として学習させ、合っていた場合はそのまま赤外線を発信して家電を操作、これを学習、みたいな形でやっていけば段々と賢くなっていく音声家電制御システムが出来上がるんじゃないでしょうか。と思って今後も作っていきます。