技術者ブログ
クラウド型WAF「Scutum(スキュータム)」の開発者/エンジニアによるブログです。
金床“Kanatoko”をはじめとする株式会社ビットフォレストの技術チームが、“WAFを支える技術”をテーマに幅広く、不定期に更新中!
Isolation ForestのJavaによる高速な実装をオープンソースで公開
はじめに
Scutumにおいて教師なしの異常検知アルゴリズムであるIsolation Forestを使うため、フルスクラッチでJavaで実装し、GitHubでオープンソースで公開しました。今回はこの実装について簡単に紹介したいと思います。
WAFでは絶え間なく沢山の通信を処理しますが、このうちの殆どは攻撃ではない正常な通信であり、ごく一部が攻撃、つまり異常になります。データが多いことから、教師なし学習で異常を見つけることで、WAFの防御性能を高めることを目指しています。
Isolation Forestの概要
2020年現在、機械学習の主流は「教師あり学習」ですが、Isolation Forestは「教師なし学習」に属します。「教師なし学習」ではグループ分けを行うクラスタリングが有名ですが、Isolation Forestはグループ分けではなく、「異常検知」あるいは「外れ値検知」を行うためのアルゴリズムです。
なぜIsolation Forestなのか?
「Kaggleで勝つデータ分析の技術」という書籍などを中心に情報収集した結果、Kaggleにおいて「性能のよい教師あり学習アルゴリズム」として立証されているものはランダムフォレストの進化系、つまり「木を大量に作って、多数決を採る」を基本にするものになっているようです。これの「教師なし学習版」はすなわちIsolation Forestとなるため、性能の良さが期待できるのではないかと考えました。
ScutumではこれまでIsolation Forestではなく、独自に開発したXBOSというアルゴリズムを使っていましたが、これを今回Isolation Forestに切り替えました。
XBOS導入時にブログを書いた時点では1つのデータセット(Kaggleクレジットカード詐欺データセット)のみの結果からXBOSの方が性能が良いと判断していました。しかし上記の理由からIsolation Forestに興味が出たため、さらに他の多くのデータセットを使って検証してみたところ、Isolation ForestがXBOSよりも良い結果を出す回数がかなり多かったのです(※1)。
Isolation Forestのアルゴリズム
この項ではざっくりとしたアルゴリズムの説明を行いますが、決定木については何となく知っているという事を前提としていますのでご注意ください。また、詳細については元の論文を参照してください。
Isolation Forestはシンプルなアルゴリズムです。まずランダムに特徴を1つ選びます。
次に、その特徴において、またしてもランダムに、区切り値(split point)を選びます。この区切り値は、その特徴における全データの値の、最小値と最大値の間のどこかになるように選びます。例えばある特徴について見たときに、下記の7個のデータがあるとします。
3,4,4,6,7,10,31最小値が3で、最大値が31です。この場合には、ランダムに3〜30の範囲でsplit pointを1つ選択します。
この時点で
- どの特徴にするか
- どの値にするか
1つノードを作ったら、全データをそのノードで処理します。そしてsplit point以下か、そうでないかによって、全データはそれぞれ2つのグループに分けられることになります。分けられたグループはさらにその先で同じように別のノードを作り、そのノードで処理して2つのグループに分け...ということを繰り返します。この繰り返しによって木が成長します。
データはどんどん細かいグループに分けられていき、最終的にはたった1つのデータしか到達しないノードとなりますが、そこまで行ったらノードの新規作成(木の成長)を終了します。あるいは、ある程度の深さ(あらかじめ決められた値)にノードが到達したら、そこで同じく終了とします。このようなノードの連なりによって定義された1つの木を「Isolation Tree」と呼びます。
生成されたIsolation Treeは、異常(外れ値)を見つけるための探索木として使うことができます。あるデータについて探索を行うと、外れ値は極端に大きかったり小さかったりするため、比較的浅い位置にあるノードで探索が終了します。一方で正常なデータは深い位置のノードに至るまで探索が終わりません(もちろん、1つの木はランダムな値から生成されていることから、これはあくまで確率的な話です)。
あるデータについて、探索が終わったノードの「深さ」が異常さを表すスコアとなります。浅い位置で探索が終わるほど異常なので、値が小さいほど異常、ということになります
そして最終的には1つの木だけではなく、100あるいはそれ以上の、大量の木を生成します。そして全ての木における、あるデータの探索結果(ノードの深さ)の平均値を算出し、それを「あるデータの異常さ」と考えるわけです。
このように大量の「Isolation Tree」から構成されるため、「Isolation Forest」という名前が付いています。
データが多くある場合には、1つの木を作る際に全部を使うことはせず、あらかじめランダムにサンプリングしたデータのサブセットを使います。このサンプリングをそれぞれの木ごとに毎回行うことで、多様性を確保し、アルゴリズムの性能を上げようという狙いがあるようです。
なぜフルスクラッチで開発したか?
Isolation ForestはPythonのライブラリであるScikit-learnに含まれているので、一般的な機械学習の場面ではそちらを使うのがよいでしょう。今回我々はScutumに組み込むためにJavaで実装されたものが必要となりましたが、JavaのIsolation Forest実装としては以下の2つを見つけました。
- Weka
- H2O
Weka版は本当に必要な最低限の実装となっているためソースコードも短く読みやすかったのですが、データの扱いにWekaのクラス依存があったために見送りました。またH2O版はその部分だけを切り出すことができなさそうだったので、こちらも諦めました。
最初は自分で書くつもりはなかったのですが、Weka版のソースコードとIsolation Forestの論文を何となく眺めていたら、Isolation Forestがとてもシンプルなアルゴリズムだということが理解できました。そしてそのうち、これなら自分で実装することができるかもしれない、と考えるようになりました。
最終的にはアルゴリズムへの理解を深めることも目的に、またマルチスレッド実装にして高速な学習ができるようにしたい、と考えて自分で開発することにしました。
この実装の特徴
このIsolation Forest実装は以下のような特徴があります。
- 外部ライブラリへの依存がない
- 学習時にマルチスレッドで処理を行う
- 学習したモデルをPOJOとして出力できる
学習の際に使うそれぞれのデータの型は処理速度を重視してプリミティブなdouble[]で固定されているため、残念ながら他の型で定義されているデータをそのまま学習に使うことはできません。また、データセット全体はdata[]のCollectionとする必要があります。
次に2について、これは私が機械学習のアルゴリズムを実装する際によくやることですが、学習がマルチスレッド(並列)の実装になっています。そのため、今現在では当然となっている多数のCPUコアを搭載するマシンでの学習で、マシンパワーをフル活用し短時間で学習を終わらせることができるようになっています。
Isolation Forestでは、それぞれの木の学習は完全に独立して行う事ができます。そのため非常にマルチスレッド処理向きのアルゴリズムです。
学習だけでなく評価についてもマルチスレッド向きのアルゴリズムなのですが、この部分の処理は短時間で済むため、マルチスレッド化する必要がなさそうなのでシングルスレッドの実装にしました。特に、次に述べるPOJOを使うと非常に高速になります。
最後に3について、これはH2Oが教師あり学習などでサポートしている方法なのですが、出来たモデルをPOJOとして出力できるようになっています。具体的にはJavaのクラスのソースコードが出力できます。このクラスはこのIsolation Forest実装に含まれるMIFModelクラスにのみ依存する形となっていて、Javaのプロダクトへの組み込みが容易に行えるようになっています。
生成されたクラスでは学習した全てのIsolation TreeがJavaのif文として表現されており、実際にどんな内容を学んだのかを見ることができます。また、クラスファイルにコンパイルされ、さらに実行時にはJITコンパイラが機械語のレベルでチューニングすることによって、非常に高速な評価が行えるようになります。この高速化はPOJOを生成することの大きなメリットです。
小さなデータセットで学習と評価をさくっと行いたいだけの場合には、POJOを生成せず、ごく普通の機械学習ライブラリのように使うこともできます。この場合にはIsolation Forestはメモリ上にデータ構造として展開されてしまうためPOJOに比べると遅くなりますが、多くのケースでは実用上まったく問題のないレベルの速度だと思います。
使い方の例
4通りの使い方の例をそれぞれExampleとして実装してあります。最も単純な、わざわざPOJOを生成しない形での学習と評価は以下のように行います。2次元のデータの例です。
MIFModelBuilder builder = new MIFModelBuilder( 100 ); List< double[] > data = new ArrayList<>(); //inlier data.add( new double[] { 5, 51 } ); data.add( new double[] { 4, 50 } ); data.add( new double[] { 4, 60 } ); data.add( new double[] { 10, 52 } ); data.add( new double[] { 8, 48 } ); data.add( new double[] { 6, 47 } ); data.add( new double[] { 1, 59 } ); data.add( new double[] { 6, 50 } ); data.add( new double[] { 9, 52 } ); data.add( new double[] { 11, 40 } ); data.add( new double[] { 2, 42 } ); data.add( new double[] { 5, 65 } ); //outlier data.add( new double[] { 20, 10 } ); data.add( new double[] { 30, 90 } ); //学習 builder.build( data ); //学習した結果を使い、いくつかのinlierとoutlierのスコアを求める //inlier p( "inliers:" ); p( builder.getScore( new double[] { 4, 52 } ) ); p( builder.getScore( new double[] { 5, 52 } ) ); p( builder.getScore( new double[] { 4, 50 } ) ); p( builder.getScore( new double[] { 6, 51 } ) ); //outlier p( "\noutliers:" ); p( builder.getScore( new double[] { 70, 80 } ) ); p( builder.getScore( new double[] { 13, 100 } ) );上記でp()は単に標準出力に出力するだけのメソッドです。出力は次のようになります。
inliers: 0.5349488834410574 0.5407374530834734 0.5342202010231406 0.5407374530834734 outliers: 0.2869349677199582 0.3375225064324313
正常(inlier)と考えられる4つのデータの数値は0.53前後で、異常(outlier)と考えられる2つのデータは0.29や0.34程度と明らかに低いスコアとなっています。
使い方の例として実装されている4つのExampleクラスはどれもmainメソッドが実装されていて通常のJavaアプリケーションとして実行可能ですが、Example3から4への流れはant経由の実行を前提にしています。
ant ant Example3 ant Example4
と実行することで、
- 学習(Example3)
- POJOの生成(Example3)
- POJOのコンパイル(ant javacタスク)
- POJOを使った評価(Example4)
Isolation Forestの数学
個人的にはゴリゴリコードを書くのは好きな反面、データサイエンスで登場する数学は苦手としていますが、今回は何とか理解できました。
まず、Isolation Forestは上の方で説明したように、たくさんの木をランダムに生成し、あるデータについて探索し、深さの平均値を採るだけで基本的な目的を達成できます。そのため、実は一切の数学を使わずとも、アルゴリズムとしては外れ値を見つけることができます。
しかし元の論文ではこの「平均値」をできるだけデータセットに依らない、汎用的な「スコア」として表現するために2分探索木に関連する数学を使っています(※2)。具体的には、ある数のデータからランダムに生成した2分探索木における、失敗時(その木に含まれない値を検索するケース)の平均探索回数です。この数値を計算するロジックは論文内では解説されておらず、論文の引用先は既に絶版となっている古いJavaのデータ構造の本でした。そのため、なぜこうなるのかを理解するのに非常に苦労しました。
最終的にはなぜ調和級数が登場するのか(=計算のコストを抑えるために近似している)も含めて理解できましたが、結構時間がかかりました(※3)。基本的には2分探索木に関連する数学なので、この部分の知識があれば理解はスムーズだろうと思われます。
処理速度的には、わざわざ「スコアを揃える」ことを行わずに、単純に結果を深さの平均値で出してしまう方が高速です。そのため、少しでも評価を速くしたい場合には、これらの数学的な補正部分はカットしてしまうことを考えてもよいかもしれません。Javaではそれほどこの部分で速度が低下することもなさそうだったので、今回の実装は論文通りに行っています。
まとめ
今回はJavaのIsolation Forest実装について簡単に紹介しました。外部依存が少なく、マルチコアあるいはメニーコアのマシンで高速に学習することが可能になっているので、ぜひ使ってみてください。 ※1: 他の4つのデータセットでは全てIsolation Forestが良い性能でした。教師なし学習のアルゴリズム同士の性能を正確に比べるのは非常に難しい、というか原理的に無理なのですが、その点については後日別のエントリにて紹介したいと思います。 ※2: Scikit-learnのIsolation Forestは、ここのスコア補正が論文とは異なる値になっています。そのため論文通りに実装した場合とは出てくるスコアが大きく異なります。また、Scikit-learn版はある深さに到達し成長を止めたノードにおいて、残っているデータの数を記憶しており、そのノードのスコアに反映する仕組みもあります。これは主に正常なデータのスコアに反映されるものであり、異常検知にはそれほど必要でない仕組みであることから、処理速度を優先して今回の実装では取り入れませんでした。 ※3: ちょうど高校数学の復習をしていたのもシグマの計算などで大きな助けになりました。