概要
TabNet は、表形式の(構造化)データ向けのわかりやすいディープ ラーニング アーキテクチャです。これには 2 つの長所が組み合わされ、ブラックボックス モデルとアンサンブルの高精度を実現しながら、より単純なツリーベース モデルの説明可能性を備えています。そのため、TabNet は金融資産の価格予測、詐欺 / サイバー攻撃 / 犯罪の検出、小売需要の予測、医療記録からの診断、商品の推奨など、表形式データを使用するさまざまなタスクに適しています。
TabNet は、シーケンシャル アテンションという、このアーキテクチャ用に特別に設計されたレイヤを使用して、モデルの各ステップで推論の対象とするモデル特徴を選択します。このメカニズムにより、モデルが予測に到達する方法を説明し、より正確なモデルの学習を支援します。この設計により、TabNet は他のニューラル ネットワークとディシジョン ツリーより優れたパフォーマンスを実現するだけでなく、わかりやすい特徴アトリビューションを提供します。
入力データ
TabNet は、入力として次のいずれかの表形式を想定します。
- トレーニング データ: モデルのトレーニングに使用されるラベルデータ。サポートされているファイル形式は次のとおりです。
- 入力スキーマ:
- 入力が CSV の場合、最初の列はターゲット変数になります。
- 入力が BigQuery の場合は、
target_column
パラメータを指定します。
CSV ファイルを準備する
入力データには、UTF-8 エンコードの CSV ファイルを使用できます。トレーニング データがカテゴリ値と数値のみで構成されている場合は、Google の前処理モジュールを使用して欠損している数値を埋め、データセットを分割し、欠損値が 10% を超える行を削除できます。それ以外の場合は、自動前処理を有効にせずにトレーニングを行うことができます。CSV ファイルの最初の列はターゲット変数です。CSV ファイルにヘッダーがある場合は、has_header
パラメータを指定する必要があります。
BigQuery データセットを準備する
入力には BigQuery データセットを使用できます。BigQuery に入力データを読み込む方法はいくつかあります。
トレーニング
TabNet を使用して単一ノード トレーニングを実行するには、次のコマンドを使用します。このコマンドは、単一の CPU マシンを使用する CustomJob
リソースを作成します。トレーニングで使用できるフラグについては、このページのフラグをご覧ください。
CSV 入力でトレーニングを行う
入力形式として CSV を使用する場合の例を次に示します。モデルのトレーニングに成功したら、トレーニング中で使用されているハイパーパラメータを最適化して、モデルの精度とパフォーマンスを改善することをおすすめします。チュートリアル ノートブックには、ハイパーパラメータ トレーニング ジョブの例が記載されています。
# URI of the TabNet Docker image. LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2' # The region to run the job in. REGION='us-central1' # Your project. PROJECT_ID="[your-project-id]" # Set the training data DATASET_NAME="petfinder" # Change to your dataset name. IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv" MODEL_TYPE="classification" # Give a unique name to your training job. DATE="$(date '+%Y%m%d_%H%M%S')" # Set a unique name for the job to run. JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}" echo $JOB_NAME # Define your bucket. YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name # Copy the csv to your bucket. TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv" gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH # Set a location for the output. OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/" echo $OUTPUT_DIR echo $JOB_NAME gcloud ai custom-jobs create \ --region=${REGION} \ --display-name=${JOB_NAME} \ --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \ --args=--preprocess \ --args=--model_type=${MODEL_TYPE} \ --args=--data_has_header \ --args=--training_data_path=${TRAINING_DATA_PATH} \ --args=--job-dir=${OUTPUT_DIR} \ --args=--max_steps=2000 \ --args=--batch_size=4096 \ --args=--learning_rate=0.01
BigQuery 入力でトレーニングを行う
入力として BigQuery を使用する場合の例を以下に示します。
# URI of the TabNet Docker image. LEARNER_IMAGE_URI='us-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2' # The region to run the job in. REGION='us-central1' # Your project. PROJECT_ID="[your-project-id]" # Set the training data DATASET_NAME="petfinder" # Change to your dataset name. IMPORT_FILE="petfinder-tabular-classification-tabnet-with-header.csv" # Give a unique name to your training job. DATE="$(date '+%Y%m%d_%H%M%S')" # Set a unique name for the job to run. JOB_NAME="tab_net_cpu_${DATASET_NAME}_${DATE}" echo $JOB_NAME # Define your bucket. YOUR_BUCKET_NAME="gs://[your-bucket-name]" # Replace by your bucket name # Copy the csv to your bucket. TRAINING_DATA_PATH="${YOUR_BUCKET_NAME}/data/${DATASET_NAME}/train.csv" gsutil cp gs://cloud-samples-data/ai-platform-unified/datasets/tabular/${IMPORT_FILE} TRAINING_DATA_PATH # Create BigQuery dataset. bq --location=${REGION} mk --dataset ${PROJECT_ID}:${DATASET_NAME} # Create BigQuery table using CSV file. TABLE_NAME="train" bq --location=${REGION} load --source_format=CSV --autodetect ${PROJECT_ID}:${DATASET_NAME}.${TABLE_NAME} ${YOUR_BUCKET_NAME}/data/petfinder/train.csv # Set a location for the output. OUTPUT_DIR="${YOUR_BUCKET_NAME}/${JOB_NAME}/" echo $OUTPUT_DIR echo $JOB_NAME gcloud ai custom-jobs create \ --region=${REGION} \ --display-name=${JOB_NAME} \ --worker-pool-spec=machine-type=n1-standard-8,replica-count=1,container-image-uri=${LEARNER_IMAGE_URI} \ --args=--preprocess \ --args=--input_type=bigquery \ --args=--model_type=classification \ --args=--stream_inputs \ --args=--bq_project=${PROJECT_ID} \ --args=--dataset_name=${DATASET_NAME} \ --args=--table_name=${TABLE_NAME} \ --args=--target_column=Adopted \ --args=--num_parallel_reads=2 \ --args=--optimizer_type=adam \ --args=--data_cache=disk \ --args=--deterministic_data=False \ --args=--loss_function_type=weighted_cross_entropy \ --args=--replace_transformed_features=True \ --args=--apply_quantile_transform=True \ --args=--apply_log_transform=True \ --args=--max_steps=2000 \ --args=--batch_size=4096 \ --args=--learning_rate=0.01 \ --args=--job-dir=${OUTPUT_DIR}
ジョブ ディレクトリについて
トレーニング ジョブが正常に完了すると、TabNet トレーニングにより、トレーニング済みモデルが他のアーティファクトと一緒に Cloud Storage バケットに作成されます。JOB_DIR
のディレクトリ構造は次のとおりです。
- artifacts/
- metadata.json
- model/(TensorFlow SavedModel ディレクトリ。
deployment_config.yaml
ファイルも格納します)- saved_model.pb
- deployment_config.yaml
- processed_data/
- test.csv
- training.csv
- validation.csv
ジョブ ディレクトリには、experiment ディレクトリ内のさまざまなモデル チェックポイント ファイルも格納されます。TensorBoard を使用して指標を可視化できます。最終的な指標も deployment_config.yaml
に含まれます。
JOB_DIR
のディレクトリ構造が上記と一致していることを確認します。
gsutil ls -a $JOB_DIR/*
チュートリアル ノートブック
Colab には、TabNet を起動するためのサンプル ノートブックが用意されています。このノートブックでは、次の使用方法も示されています。
BigQuery 入力を使用したトレーニング。
GPU を使用した分散トレーニング。
ハイパーパラメータ調整。
フラグ
モデルをトレーニングするときは、次の一般的なトレーニング フラグと TabNet 固有のトレーニング フラグを使用します。
一般的なトレーニング フラグ
よく使用されるカスタム トレーニング フラグは次のとおりです。詳細については、カスタム トレーニング ジョブを作成するをご覧ください。
worker-pool-spec
: カスタムジョブで使用されるワーカープールの構成。複数のワーカープールを含むカスタムジョブを作成するには、複数のworker-pool-spec
構成を指定します。worker-pool-spec
には、次のフィールドを含めることができます。これらのフィールドは、WorkerPoolSpec API メッセージ内で、対応するフィールドと一緒に示されます。machine-type
: プールのマシンタイプ。サポートされているマシンの一覧については、マシンタイプをご覧ください。replica-count
: プール内のマシンのレプリカの数。container-image-uri
: 各ワーカーで実行する Docker イメージ。TabNet 組み込みアルゴリズムを使用するには、Docker イメージをus-docker.pkg.dev/vertex-ai-restricted/builtin-algorithm/tab_net_v2:latest
に設定する必要があります。
display-name
: ジョブの名前。region
: ジョブを実行するリージョン。
TabNet 固有のトレーニング フラグ
次の表に、TabNet トレーニング ジョブで設定可能なランタイム パラメータを示します。
パラメータ | データ型 | 説明 | 必須 |
---|---|---|---|
preprocess |
ブール値の引数 | 自動前処理を有効にする場合は指定します。 | いいえ |
job_dir |
文字列 | モデル出力ファイルが保存される Cloud Storage ディレクトリ。 | はい |
input_metadata_path |
文字列 | トレーニング データセットの TabNet 固有のメタデータへの Cloud Storage パス。メタデータの作成方法については、上記をご覧ください。 | いいえ |
training_data_path |
文字列 | トレーニング データが格納される Cloud Storage パターン。 | はい |
validation_data_path |
文字列 | 評価データが格納される Cloud Storage パターン。 | いいえ |
test_data_path |
文字列 | テストデータが格納される Cloud Storage パターン。 | はい |
input_type |
文字列 | 「bigquery」または「csv」- 入力の表形式データのタイプ。csv を指定すると、最初の列はターゲットとして扱われます。CSV ファイルにヘッダーがある場合は、data_has_header フラグも渡します。「bigquery」を使用する場合は、トレーニング / 検証データのパスを指定できます。また、BigQuery プロジェクト、データセット、テーブル名を指定して前処理を行い、トレーニング データセットと検証データセットを作成することもできます。 | いいえ。デフォルトは「csv」です。 |
model_type |
文字列 | 分類や回帰などの学習タスク | はい |
split_column |
文字列 | トレーニング、検証、テスト分割の作成に使用される列名。列の値(table ['split_column'])には、「TRAIN」、「VALIDATE」、「TEST」のいずれかを指定する必要があります。「TEST」は省略可能です。BigQuery 入力にのみ適用されます。 | いいえ |
train_batch_size |
整数 | トレーニングのバッチサイズ。 | いいえ。デフォルトは 1024 です。 |
eval_split |
浮動小数点数 | validation_data_path が指定されていない場合に、評価データセットに使用する分割割合。 |
いいえ。デフォルトは 0.2 です。 |
learning_rate |
浮動小数点数 | トレーニングの学習率。 | いいえ。デフォルトは、指定されたオプティマイザーのデフォルトの学習率です。 |
eval_frequency_secs |
整数 | 評価とチェックポイント処理が行われる頻度。デフォルトは 600 です。 | いいえ |
num_parallel_reads |
整数 | 入力ファイルの読み取りに使用されるスレッド数。多くの場合、パフォーマンスを最大限に高めるため、マシンの CPU 数と同じにするか、若干少なく設定することをおすすめします。たとえば、デフォルトの選択肢として GPU あたり 6 個が最適です。 | はい |
data_cache |
文字列 | データのキャッシュ先を選択します。「memory」、「disk」、「no_cache」のいずれかを指定します。大規模なデータセットの場合、データのキャッシュ先をメモリにすると、メモリ不足エラーがスローされるため、「disk」を選択することをおすすめします。構成ファイルでディスクサイズを指定できます(以下の例を参照)。大規模な(B スケール)のデータセットの場合は、データの書き込みに十分な大きさ(TB サイズなど)のディスクをリクエストする必要があります。 | いいえ。デフォルトは「memory」です。 |
bq_project |
文字列 | BigQuery プロジェクトの名前。input_type=bigquery にして –preprocessing フラグを使用する場合、このフラグは必須です。トレーニング、検証、テストのデータパスを指定する代わりに、これを使用できます。 | いいえ |
dataset_name |
文字列 | BigQuery データセットの名前。input_type=bigquery にして –preprocessing フラグを使用する場合、このフラグは必須です。トレーニング、検証、テストのデータパスを指定する代わりに、これを使用できます。 | いいえ |
table_name |
文字列 | BigQuery テーブルの名前input_type=bigquery にして –preprocessing フラグを使用する場合、このフラグは必須です。トレーニング、検証、テストのデータパスを指定する代わりに、これを使用できます。 | いいえ |
loss_function_type |
文字列 | TabNet には、いくつかの損失関数型があります。回帰の場合: mse / mae。分類の場合: cross_entropy / weight_cross_entropy / focal_loss。 | いいえ。回帰の場合、デフォルトは「mse」です。分類の場合、デフォルトは「cross_entropy」です。 |
deterministic_data |
ブール値の引数 | 表形式データからのデータ読み取りの決定論。デフォルトは False に設定されています。True に設定すると、テストは確定的になります。大規模なデータセットで高速トレーニングを行う場合は、deterministic_data=False 設定をおすすめします。結果はランダムに決まりますが、大規模なデータセットでは無視できる程度です。有限精度の代数演算の順序付けのため、map-reduce ではランダム性が生じるため、分散トレーニングでは決定論は保証されません。ただし、大規模なデータセットでは無視できる程度です。100% の決定論が求められる場合は、deterministic_data=True を設定するだけでなく、単一 GPU を使用するトレーニング(MACHINE_TYPE="n1-highmem-8" など)をおすすめします。 | いいえ。デフォルトは False です。 |
stream_inputs |
ブール値の引数 | 入力データをローカルにダウンロードするのではなく、Cloud Storage からストリーミングします。高速実行のため、このオプションを使用することをおすすめします。 | いいえ |
large_category_dim |
整数 | 埋め込みの次元数。1 つのカテゴリ列の個別のカテゴリ数が big_category_thresh よりも大きい場合、1 次元の埋め込みではなく、large_category_dim 次元の埋め込みを使用します。デフォルトは 1 です。計算効率や説明可能性ではなく、精度の向上が主な目標である場合は、この数値を増やすことをおすすめします(たとえば、通常の場合は 5 程度。データセット内のカテゴリ数が非常に多い場合は、10 程度)。 | いいえ。デフォルトは 1 です。 |
large_category_thresh |
整数 | カテゴリ列の基数しきい値。1 つのカテゴリ列の個別のカテゴリ数が big_category_thresh よりも大きい場合、1 次元の埋め込みではなく、large_category_dim 次元の埋め込みを使用します。デフォルト値は 300 です。計算効率や説明可能性ではなく、精度の向上が主な目標である場合は、数値を下げることをおすすめします(たとえば 10 程度)。 | いいえ。デフォルトは 300 です。 |
yeo_johnson_transform |
ブール値の引数 | トレーニング可能な Yeo-Johnson Power Transform を有効にします(デフォルトは無効)。Yeo-Johnson Power Transform の詳細については、https://www.stat.umn.edu/arc/yjpower.pdf をご覧ください。この実装では、変換パラメータは TabNet とともにエンドツーエンドのトレーニングで学習可能です。 | いいえ |
apply_log_transform |
ブール値の引数 | メタデータにログ変換の統計情報が含まれていて、このフラグが true の場合、入力特徴はログ変換されます。変換を使用しない場合は false を使用し、変換を使用する場合は true(デフォルト)を使用します。特に、数値分布が偏ったデータセットの場合は、ログ変換が非常に役立つ可能性があります。 | いいえ |
apply_quantile_transform |
ブール値の引数 | メタデータに分位点の統計情報が含まれていて、このフラグが true の場合、入力特徴は分位点変換されます。変換を使用しない場合は false を使用し、変換を使用する場合は true(デフォルト)を使用します。特に、数値分布が偏ったデータセットの場合は、分位点変換が非常に役立つ可能性があります。現在、BigQuery input_type でサポートされています。 | いいえ |