コンテンツに移動
AI & 機械学習

TabNet on Vertex AI の改善: 高パフォーマンスでスケーラブルな表形式ディープ ラーニング

2022年11月4日
Google Cloud Japan Team

※この投稿は米国時間 2022 年 10 月 27 日に、Google Cloud blog に投稿されたものの抄訳です。

データ サイエンティストは、企業内で最も一般的なデータ型である表形式(つまり、構造化)データを含む機械学習(ML)問題を解く際に、さまざまなトレードオフに基づいてモデルを選択します。その中でもディシジョン ツリーは解釈が容易でトレーニング速度が速く、小規模なデータセットから迅速に高い精度を得られるため、人気のあるモデルです。一方、ディープ ニューラル ネットワークは、より大規模なデータセットでの精度に優れ、エンドツーエンド学習のメリットもありますが、ブラックボックス的で解釈が難しいという側面もあります。Google AI が開発したわかりやすいディープ ラーニング アーキテクチャである TabNet は、単純なツリーベース モデルのように説明可能であり、複雑なブラックボックス モデルとアンサンブルの高い精度を実現できるという、両方の長所を兼ね備えています。

この度、その TabNet が Vertex AI Tabular Workflows で利用できるようになりました。Tabular Workflows の最適化されたスケーラブルなフルマネージド パイプラインでは、実装の詳細を気にせずに TabNet を使用できます。また、Vertex の MLOps 機能を使用して TabNet を簡単にデプロイできます。TabNet on Vertex AI Tabular Workflows は、膨大な表形式データセットに効率的にスケーリングできるように最適化されています。また、オリジナルの TabNet に機械学習の改良を加え、実世界のデータ課題に対してより高い精度を実現しています。

TabNet on Vertex AI は、金融資産の価格予測、詐欺 / サイバー攻撃 / 犯罪の検出、小売需要の予測、ユーザー モデリング、信用 / リスク評価、医療記録からの診断、商品の推奨など、精度と同様にモデルの説明可能性が重要な表形式データを使用するさまざまなタスクに適しています。

Tabnet の概要

TabNet は、シーケンシャル アテンションに基づいて各ステップで推論元となるモデル特徴を選択する、特別に設計されたアーキテクチャ(図 1 の概要を参照)を備えています。このメカニズムにより、モデルがどのように予測に至るのかを説明することが可能になり、考え抜かれた設計で優れた精度を実現しています。TabNet は、他のモデル(ニューラル ネットワークやディシジョン ツリーなど)よりも優れた性能を発揮するだけでなく、解釈可能な特徴アトリビューションも提供します。学術的なベンチマークでの結果など、詳細は AAAI 2021 の論文に記載されています。
https://storage.googleapis.com/gweb-cloudblog-publish/images/Figure_1_jCXvAoR.max-2000x2000.jpg
図 1: TabNet アーキテクチャ

TabNet は発表以来、多種多様な業界のさまざまな企業や、価値の高い表形式データのアプリケーションから大きな支持を得ています(その多くには、ディープ ラーニングが先験的に使われてさえいなかったものも含まれます)。MicrosoftLudwigRavelinDetermined など、数多くの企業で採用されています。TabNet に対するお客様の関心度が高いことから、実際のディープ ラーニングの開発と運用化のニーズを考慮して Vertex で利用できるようにするとともに、パフォーマンスと効率性の向上に取り組んできました。

TabNet on Vertex AI Tabular Workflows のハイライト

超大規模データセットへのスケーリング

BigQuery などのクラウド テクノロジーの進歩により、企業はますます多くの表形式データを収集するようになり、数十億のサンプルと数百 / 数千の特徴を持つデータセットが標準になりつつあります。一般的に、ディープ ラーニング モデルは、予測を促進する複雑なパターンをより適切に学習できるため、より多くのデータサンプルとより多くの特徴から最適な方法で優れた学習成果を得ることが可能です。しかし、膨大なデータに対するモデル開発を考えると、コンピューティングが大きな課題となってきます。そのため、モデル開発には高額な費用と長い時間がかかり、多くのお客様が大規模なデータセットを十分に活用するうえでのボトルネックとなっています。TabNet on Tabular Workflows を使用すると、非常に大きな表形式データセットへのスケーリングをより効率的に行えるようになります。

実装の重要な側面: TabNet のアーキテクチャは、テンソル代数演算を中心に構成され、非常に大きなバッチサイズを利用しており、高い計算強度を持つ(転送される各データバイトに対して多くの演算を行う)など、スケーリングに関する独自のメリットがあります。これにより、多数の GPU を使用した効率的な分散トレーニングが可能となり、Google が改良を加えた実装で TabNet トレーニングをスケーリングするために活用されています。

TabNet on Vertex AI Tabular Workflows では、ユーザーが Vertex AI にかける費用を最大限に回収できるように、ハードウェア使用率を最大化するためのデータとトレーニングのパイプラインが慎重に設計されています。次の機能により、TabNet on Tabular Workflows でのスケーリングが可能になります。

  • Tensorflow のベスト プラクティスを反映し、分散トレーニングにおいて GPU 使用率を最大化するように最適化されたパイプラインで、複数の CPU により行われるデータの並列読み込み。

  • コンピューティング要件の高い大規模データセットで大幅な高速化を実現する、複数 GPU でのトレーニング。ユーザーは複数の GPU を搭載した GCP 上で利用可能な任意のマシンを指定でき、分散トレーニングを使用してそのマシン上で自動的にモデルを実行します。

  • 分散トレーニングで効率的なデータ並列処理を実現するために、Tensorflow ミラーリング分散戦略を使用して、多くの GPU にまたがるデータ並列処理をサポートしています。その結果、100~1,000 の特徴を持つ 10 億規模のデータセットにおいて、複数の GPU を使用して 80% 超の使用率を実証しています。

ディープ ラーニング モデルの標準的な実装では、GPU の使用率が低くなり、リソースを効率的に使用できない可能性があります。TabNet on Vertext という Google の実装では、大規模データセットへの費用に対して最大限の効果が得られます。

実際の顧客データでの例: 大規模なデータセットを使用しており、高速なトレーニングが重要である企業のユースケースに特化したトレーニング時間のベンチマークを実施しました。ある代表的な例では、1 つの NVIDIA_TESLA_V100 GPU を使用して、約 500 万サンプルのデータセットに対して、約 1 時間で最先端のパフォーマンスを達成しました。別の例では、4 つの NVIDIA_TESLA_V100 GPU を使用して、約 14 億サンプルのデータセットに対して、約 14 時間で最先端のパフォーマンスを達成しました。

実世界のデータ課題を考慮した精度の向上

TabNet on Vertex AI Tabular Workflows は、オリジナル版と比較して機械学習機能が向上しています。特に、実世界でよく見られる表形式データの課題に焦点を当てています。実世界の表形式データの一般的な課題として、偏った分布を持つ数値列が挙げられます。この課題に対して、Google は TabNet の学習を向上させる学習可能な前処理レイヤ(たとえば、パラメータ化された累乗変換族や分位点変換など)を生成しました。また、カテゴリデータのカテゴリ数が多いという共通課題もありますが、これに対しては、調整可能な高次元埋め込みを採用しています。もう一つの課題は、ラベル分布の不均衡です。これに対しては、さまざまな損失関数族(Focal Loss や微分可能な AUC バリアントなど)を追加しました。これらの追加により、一部の例ではパフォーマンスが大幅に向上することが確認されています。

実際の顧客データを使用したケーススタディ: 大手企業のお客様と協力し、推奨、ランキング、不正行為の検出、到着予定時刻予測など、幅広いユースケースで従来のアルゴリズムを TabNet に置き換えた事例をご紹介します。代表的な例では、TabNet が大手顧客向けの高度なモデル アンサンブルに匹敵する結果となりました。ほとんどのケースでアンサンブルを上回り、いくつかの重要なタスクでは 10% 近くのエラー削減につながりました。このモデルで 1% 改善するごとに数百万ドルの節約につながったことを考えると、これは素晴らしい結果と言えます。

すぐに使える説明可能性

精度の高さに加えて、TabNet のもう一つの主なメリットは、多層パーセプトロンのような従来のディープ ニューラル ネットワーク(DNN)モデルとは異なり、そのアーキテクチャに最初から説明可能性が含まれていることです。この Vertex Tabular Workflows の新機能により、トレーニング済みの TabNet モデルの説明を可視化しやすくなり、TabNet モデルがどのようにその決定に至ったのかを素早く把握することが可能になりました。TabNet は学習済みのマスクを介して特徴重要度の出力を提供します。このマスクは、モデルの特定の決定ステップにおいて、ある特徴が選択されているかどうかを示すものです。以下は、マスク値に基づくローカルおよびグローバルな特徴の重要度を可視化したものです。特定のサンプルのマスク値が高いほど、そのサンプルにおいて対応する特徴の重要度が高くなります。TabNet の説明可能性は、Shapley 値のような推定に計算コストがかかる事後的な方法に対して、モデルの中間層から容易に説明を得られるという基本的なメリットがあります。さらに、事後的な説明は非線形ブラックボックス関数の近似に基づいているのに対し、TabNet の説明は実際の意思決定が何に基づいているかを基準としています。

説明可能性の例: この種の説明可能性により、どのようなことが達成できるかを説明していきます。以下の図 2 では国勢調査データセットの特徴の重要度を示しています。この図から、年収 5 万ドル超を達成できるかを予測するうえで最も重要な特徴は、教育、職業、1 週間当たりの労働時間であることがわかります(対応する列の色が明るくなっています)。説明可能性の機能はサンプル単位であり、各サンプルの特徴の重要度を個別に取得できます。

https://storage.googleapis.com/gweb-cloudblog-publish/images/Figure_2_G4v3h51.max-800x800.jpg
図 2: グローバル インスタンス単位の特徴選択を示す国勢調査データの合計特徴重要度マスク。色が明るいほど高い値を示します。各行は、各入力インスタンスのマスクを表します。この図には、30 個の入力インスタンスの出力マスクがあります。各列は 1 つの特徴を表します。たとえば、最初の列は国勢調査データの年齢特徴を表し、2 番目の列は職業分類特徴を表します。この図は、教育、職業、1 週間あたりの労働時間が最も重要な特徴である(対応する列の網掛けが明るい)ことを示しています。

フルマネージド Vertex Pipelines のメリット

TabNet on Vertex Tabular Workflows を活用すると、モデルの開発とデプロイのタスクが非常に簡単になります。コードを記述することなく、トレーニング済みの TabNet モデルを取得してアプリケーションにデプロイし、Vertex のマネージド パイプラインで可能になる MLOps 機能を使用できます。これらのメリットを一部ご紹介します。  

  • Vertex AI PipelinesVertex AI Experiments などのプロダクトを含む、自動化された ML を大規模に実装するための Vertex AI ML Ops との互換性。

  • デプロイの利便性: Vertex AI 予測サービスでは、バッチモードとオンライン モードの両方が最初からサポートされています。

  • ユーザーのドメイン知識を最大限に活用できる、カスタマイズ可能な特徴量エンジニアリング。

  • Google の最先端の検索アルゴリズムを使用して、最適なハイパーパラメータを特定するための自動チューニングを行い、データセットのサイズ、予測タイプ、トレーニング予算に基づいて適切なハイパーパラメータの検索空間を自動的に選択します。

  • デプロイされたモデルのトラッキングと便利な評価ツール。

  • ユーザー ジャーニーが統合され、他のモデル(AutoMLWide & Deep Networks など)との比較ベンチマークが容易になります。

  • 国際的なワークロードに対応するためのマルチリージョン可用性。

詳細

お使いの表形式データセットで TabNet on Vertex AI を試してみたい場合は、Vertex AI の表形式ワークフローをご確認のうえ、フォームにご記入ください。


謝辞: このブログ投稿にご助力いただいた Nate Yoder、Yihe Dong、Dawei Jia、Alex Martin、Helin Wang、Henry Tappen、Tomas Pfister の各氏に感謝の意を表します。

- Long T. Le(Google Cloud AI ソフトウェア エンジニア)
- Sercan Ö. Arik(Google Cloud AI リサーチ サイエンティスト

投稿先