VertexAI で TabNet 組み込みアルゴリズムを使ってみる

概要

TabNet は、表形式の(構造化)データ向けのわかりやすいディープ ラーニング アーキテクチャです。これには 2 つの長所が組み合わされ、ブラックボックス モデルとアンサンブルの高精度を実現しながら、より単純なツリーベース モデルの説明可能性を備えています。そのため、TabNet は金融資産の価格予測、詐欺 / サイバー攻撃 / 犯罪の検出、小売需要の予測、医療記録からの診断、商品の推奨など、表形式データを使用するさまざまなタスクに適しています。

TabNet は、シーケンシャル アテンションという、このアーキテクチャ用に特別に設計されたレイヤを使用して、モデルの各ステップで推論の対象とするモデル特徴を選択します。このメカニズムにより、モデルが予測に到達する方法を説明し、より正確なモデルの学習を支援します。この設計により、TabNet は他のニューラル ネットワークとディシジョン ツリーより優れたパフォーマンスを実現するだけでなく、わかりやすい特徴アトリビューションを提供します。

入力データ

TabNet は、入力として次のいずれかの表形式を想定します。

  • トレーニング データ: モデルのトレーニングに使用されるラベルデータ。サポートされているファイル形式は次のとおりです。
  • 入力スキーマ:
    • 入力が CSV の場合、最初の列はターゲット変数になります。
    • 入力が BigQuery の場合は、target_column パラメータを指定します。

CSV ファイルを準備する

入力データには、UTF-8 エンコードの CSV ファイルを使用できます。トレーニング データがカテゴリ値と数値のみで構成されている場合は、Google の前処理モジュールを使用して欠損している数値を埋め、データセットを分割し、欠損値が 10% を超える行を削除できます。それ以外の場合は、自動前処理を有効にせずにトレーニングを行うことができます。CSV ファイルの最初の列はターゲット変数です。CSV ファイルにヘッダーがある場合は、has_header パラメータを指定する必要があります。

BigQuery データセットを準備する

入力には BigQuery データセットを使用できます。BigQuery に入力データを読み込む方法はいくつかあります。

トレーニング

TabNet を使用して単一ノード トレーニングを実行するには、次のコマンドを使用します。このコマンドは、単一の CPU マシンを使用する CustomJob リソースを作成します。トレーニングで使用できるフラグについては、このページのフラグをご覧ください。

CSV 入力でトレーニングを行う

入力形式として CSV を使用する場合の例を次に示します。モデルのトレーニングに成功したら、トレーニング中で使用されているハイパーパラメータを最適化して、モデルの精度とパフォーマンスを改善することをおすすめします。チュートリアル ノートブックには、ハイパーパラメータ トレーニング ジョブの例が記載されています。

# URI of the TabNet Docker image.
LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2'

# The region to run the job in.
REGION='us-central1'

# Your project.
PROJECT_ID="[your-project-id]"

# Set the training data
DATASET_NAME="petfinder"  # Change to your dataset name.
IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv"
MODEL_TYPE="classification"

# Give a unique name to your training job.
DATE="$(date '+%Y%m%d_%H%M%S')"

# Set a unique name for the job to run.
JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}"
echo $JOB_NAME

# Define your bucket.
YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name

# Copy the csv to your bucket.
TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv"
gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH

# Set a location for the output.
OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/"

echo $OUTPUT_DIR
echo $JOB_NAME

gcloud ai custom-jobs create \
  --region=${REGION} \
  --display-name=${JOB_NAME} \
  --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \
  --args=--preprocess \
  --args=--model_type=${MODEL_TYPE} \
  --args=--data_has_header \
  --args=--training_data_path=${TRAINING_DATA_PATH} \
  --args=--job-dir=${OUTPUT_DIR} \
  --args=--max_steps=2000 \
  --args=--batch_size=4096 \
  --args=--learning_rate=0.01

BigQuery 入力でトレーニングを行う

入力として BigQuery を使用する場合の例を以下に示します。

# URI of the TabNet Docker image.
LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2'

# The region to run the job in.
REGION='us-central1'

# Your project.
PROJECT_ID="[your-project-id]"

# Set the training data
DATASET_NAME="petfinder"  # Change to your dataset name.
IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv"

# Give a unique name to your training job.
DATE="$(date '+%Y%m%d_%H%M%S')"

# Set a unique name for the job to run.
JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}"
echo $JOB_NAME

# Define your bucket.
YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name

# Copy the csv to your bucket.
TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv"
gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH

# Create BigQuery dataset.
bq --location=${REGION} mk --dataset ${PROJECT_ID}:${DATASET_NAME}

# Create BigQuery table using CSV file.
TABLE_NAME="train"
bq --location=${REGION} load --source_format=CSV --autodetect ${PROJECT_ID}:${DATASET_NAME}.${TABLE_NAME} ${YOUR_BUCKET_NAME}/data/petfinder/train.csv

# Set a location for the output.
OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/"
echo $OUTPUT_DIR
echo $JOB_NAME

gcloud ai custom-jobs create \
  --region=${REGION} \
  --display-name=${JOB_NAME} \
  --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \
  --args=--preprocess \
  --args=--input_type=bigquery \
  --args=--model_type=classification \
  --args=--stream_inputs \
  --args=--bq_project=${PROJECT_ID} \
  --args=--dataset_name=${DATASET_NAME} \
  --args=--table_name=${TABLE_NAME} \
  --args=--target_column=Adopted \
  --args=--num_parallel_reads=2 \
  --args=--optimizer_type=adam \
  --args=--data_cache=disk \
  --args=--deterministic_data=False \
  --args=--loss_function_type=weighted_cross_entropy \
  --args=--replace_transformed_features=True \
  --args=--apply_quantile_transform=True \
  --args=--apply_log_transform=True \
  --args=--max_steps=2000 \
  --args=--batch_size=4096 \
  --args=--learning_rate=0.01 \
  --args=--job-dir=${OUTPUT_DIR}

ジョブ ディレクトリについて

トレーニング ジョブが正常に完了すると、TabNet トレーニングにより、トレーニング済みモデルが他のアーティファクトと一緒に Cloud Storage バケットに作成されます。JOB_DIR のディレクトリ構造は次のとおりです。

  • artifacts/
    • metadata.json
  • model/(TensorFlow SavedModel ディレクトリdeployment_config.yaml ファイルも格納します)
    • saved_model.pb
    • deployment_config.yaml
  • processed_data/
    • test.csv
    • training.csv
    • validation.csv

ジョブ ディレクトリには、experiment ディレクトリ内のさまざまなモデル チェックポイント ファイルも格納されます。TensorBoard を使用して指標を可視化できます。最終的な指標も deployment_config.yaml に含まれます。

JOB_DIR のディレクトリ構造が上記と一致していることを確認します。

gsutil ls -a $JOB_DIR/*

チュートリアル ノートブック

Colab には、TabNet を起動するためのサンプル ノートブックが用意されています。このノートブックでは、次の使用方法も示されています。

  • BigQuery 入力を使用したトレーニング。

  • GPU を使用した分散トレーニング。

  • ハイパーパラメータ調整。

フラグ

モデルをトレーニングするときは、次の一般的なトレーニング フラグと TabNet 固有のトレーニング フラグを使用します。

一般的なトレーニング フラグ

よく使用されるカスタム トレーニング フラグは次のとおりです。詳細については、カスタム トレーニング ジョブを作成するをご覧ください。

  • worker-pool-spec: カスタムジョブで使用されるワーカープールの構成。複数のワーカープールを含むカスタムジョブを作成するには、複数の worker-pool-spec 構成を指定します。

    worker-pool-spec には、次のフィールドを含めることができます。これらのフィールドは、WorkerPoolSpec API メッセージ内で、対応するフィールドと一緒に示されます。

    • machine-type: プールのマシンタイプ。サポートされているマシンの一覧については、マシンタイプをご覧ください。
    • replica-count: プール内のマシンのレプリカの数。
    • container-image-uri: 各ワーカーで実行する Docker イメージ。TabNet 組み込みアルゴリズムを使用するには、Docker イメージを us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2:latest に設定する必要があります。
  • display-name: ジョブの名前。

  • region: ジョブを実行するリージョン。

TabNet 固有のトレーニング フラグ

次の表に、TabNet トレーニング ジョブで設定可能なランタイム パラメータを示します。

パラメータ データ型 説明 必須
preprocess ブール値の引数 自動前処理を有効にする場合は指定します。 いいえ
job_dir 文字列 モデル出力ファイルが保存される Cloud Storage ディレクトリ。 はい
input_metadata_path 文字列 トレーニング データセットの TabNet 固有のメタデータへの Cloud Storage パス。メタデータの作成方法については、上記をご覧ください。 いいえ
training_data_path 文字列 トレーニング データが格納される Cloud Storage パターン。 はい
validation_data_path 文字列 評価データが格納される Cloud Storage パターン。 いいえ
test_data_path 文字列 テストデータが格納される Cloud Storage パターン。 はい
input_type 文字列 「bigquery」または「csv」- 入力の表形式データのタイプ。csv を指定すると、最初の列はターゲットとして扱われます。CSV ファイルにヘッダーがある場合は、data_has_header フラグも渡します。「bigquery」を使用する場合は、トレーニング / 検証データのパスを指定できます。また、BigQuery プロジェクト、データセット、テーブル名を指定して前処理を行い、トレーニング データセットと検証データセットを作成することもできます。 いいえ。デフォルトは「csv」です。
model_type 文字列 分類や回帰などの学習タスク はい
split_column 文字列 トレーニング、検証、テスト分割の作成に使用される列名。列の値(table ['split_column'])には、「TRAIN」、「VALIDATE」、「TEST」のいずれかを指定する必要があります。「TEST」は省略可能です。BigQuery 入力にのみ適用されます。 いいえ
train_batch_size 整数 トレーニングのバッチサイズ。 いいえ。デフォルトは 1024 です。
eval_split 浮動小数点数 validation_data_path が指定されていない場合に、評価データセットに使用する分割割合。 いいえ。デフォルトは 0.2 です。
learning_rate 浮動小数点数 トレーニングの学習率。 いいえ。デフォルトは、指定されたオプティマイザーのデフォルトの学習率です。
eval_frequency_secs 整数 評価とチェックポイント処理が行われる頻度。デフォルトは 600 です。 いいえ
num_parallel_reads 整数 入力ファイルの読み取りに使用されるスレッド数。多くの場合、パフォーマンスを最大限に高めるため、マシンの CPU 数と同じにするか、若干少なく設定することをおすすめします。たとえば、デフォルトの選択肢として GPU あたり 6 個が最適です。 はい
data_cache 文字列 データのキャッシュ先を選択します。「memory」、「disk」、「no_cache」のいずれかを指定します。大規模なデータセットの場合、データのキャッシュ先をメモリにすると、メモリ不足エラーがスローされるため、「disk」を選択することをおすすめします。構成ファイルでディスクサイズを指定できます(以下の例を参照)。大規模な(B スケール)のデータセットの場合は、データの書き込みに十分な大きさ(TB サイズなど)のディスクをリクエストする必要があります。 いいえ。デフォルトは「memory」です。
bq_project 文字列 BigQuery プロジェクトの名前。input_type=bigquery にして –preprocessing フラグを使用する場合、このフラグは必須です。トレーニング、検証、テストのデータパスを指定する代わりに、これを使用できます。 いいえ
dataset_name 文字列 BigQuery データセットの名前。input_type=bigquery にして –preprocessing フラグを使用する場合、このフラグは必須です。トレーニング、検証、テストのデータパスを指定する代わりに、これを使用できます。 いいえ
table_name 文字列 BigQuery テーブルの名前input_type=bigquery にして –preprocessing フラグを使用する場合、このフラグは必須です。トレーニング、検証、テストのデータパスを指定する代わりに、これを使用できます。 いいえ
loss_function_type 文字列 TabNet には、いくつかの損失関数型があります。回帰の場合: mse / mae。分類の場合: cross_entropy / weight_cross_entropy / focal_loss。 いいえ。回帰の場合、デフォルトは「mse」です。分類の場合、デフォルトは「cross_entropy」です。
deterministic_data ブール値の引数 表形式データからのデータ読み取りの決定論。デフォルトは False に設定されています。True に設定すると、テストは確定的になります。大規模なデータセットで高速トレーニングを行う場合は、deterministic_data=False 設定をおすすめします。結果はランダムに決まりますが、大規模なデータセットでは無視できる程度です。有限精度の代数演算の順序付けのため、map-reduce ではランダム性が生じるため、分散トレーニングでは決定論は保証されません。ただし、大規模なデータセットでは無視できる程度です。100% の決定論が求められる場合は、deterministic_data=True を設定するだけでなく、単一 GPU を使用するトレーニング(MACHINE_TYPE="n1-highmem-8" など)をおすすめします。 いいえ。デフォルトは False です。
stream_inputs ブール値の引数 入力データをローカルにダウンロードするのではなく、Cloud Storage からストリーミングします。高速実行のため、このオプションを使用することをおすすめします。 いいえ
large_category_dim 整数 埋め込みの次元数。1 つのカテゴリ列の個別のカテゴリ数が big_category_thresh よりも大きい場合、1 次元の埋め込みではなく、large_category_dim 次元の埋め込みを使用します。デフォルトは 1 です。計算効率や説明可能性ではなく、精度の向上が主な目標である場合は、この数値を増やすことをおすすめします(たとえば、通常の場合は 5 程度。データセット内のカテゴリ数が非常に多い場合は、10 程度)。 いいえ。デフォルトは 1 です。
large_category_thresh 整数 カテゴリ列の基数しきい値。1 つのカテゴリ列の個別のカテゴリ数が big_category_thresh よりも大きい場合、1 次元の埋め込みではなく、large_category_dim 次元の埋め込みを使用します。デフォルト値は 300 です。計算効率や説明可能性ではなく、精度の向上が主な目標である場合は、数値を下げることをおすすめします(たとえば 10 程度)。 いいえ。デフォルトは 300 です。
yeo_johnson_transform ブール値の引数 トレーニング可能な Yeo-Johnson Power Transform を有効にします(デフォルトは無効)。Yeo-Johnson Power Transform の詳細については、https://www.stat.umn.edu/arc/yjpower.pdf をご覧ください。この実装では、変換パラメータは TabNet とともにエンドツーエンドのトレーニングで学習可能です。 いいえ
apply_log_transform ブール値の引数 メタデータにログ変換の統計情報が含まれていて、このフラグが true の場合、入力特徴はログ変換されます。変換を使用しない場合は false を使用し、変換を使用する場合は true(デフォルト)を使用します。特に、数値分布が偏ったデータセットの場合は、ログ変換が非常に役立つ可能性があります。 いいえ
apply_quantile_transform ブール値の引数 メタデータに分位点の統計情報が含まれていて、このフラグが true の場合、入力特徴は分位点変換されます。変換を使用しない場合は false を使用し、変換を使用する場合は true(デフォルト)を使用します。特に、数値分布が偏ったデータセットの場合は、分位点変換が非常に役立つ可能性があります。現在、BigQuery input_type でサポートされています。 いいえ

次のステップ