TabNet でモデルをトレーニングする

このページでは、TabNet の表形式のワークフローを使用して表形式のデータセットから分類モデルまたは回帰モデルをトレーニングする方法について説明します。

TabNet 用の表形式のワークフローは、2 つのバージョンで使用できます。

  • HyperparameterTuningJob は、モデルのトレーニングに使用するハイパーパラメータ値の最適な組み合わせを検索します。
  • CustomJob では、モデルのトレーニングに使用するハイパーパラメータ値を指定できます。必要なハイパーパラメータ値が正確にわかっている場合は、それらを検索する代わりに指定するとトレーニング リソースを節約できます。

このワークフローで使用されるサービス アカウントについては、表形式ワークフローのサービス アカウントをご覧ください。

ワークフローの API

このワークフローでは、次の API を使用します。

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

HyperparameterTuningJob を使用してモデルをトレーニングする

次のサンプルコードは、HyperparameterTuningJob パイプラインを実行する方法を示しています。

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

pipeline_job.run() でオプションの service_account パラメータを使用すると、Vertex AI Pipelines サービス アカウントを任意のアカウントに設定できます。

パイプラインとパラメータ値は、次の関数で定義されます。トレーニング データは、Cloud Storage の CSV ファイルか、BigQuery のテーブルのいずれかです。

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters パラメータのサブセットは次のとおりです。

パラメータ名 定義
data_source_csv_filenames 文字列 Cloud Storage に保存されている CSV の URI。
data_source_bigquery_table_path 文字列 BigQuery テーブルの URI。
dataflow_service_account 文字列 (省略可)Dataflow ジョブを実行するカスタム サービス アカウント。プライベート IP と特定の VPC サブネットを使用するように Dataflow ジョブを構成できます。このパラメータは、デフォルトの Dataflow ワーカー サービス アカウントのオーバーライドとして機能します。
study_spec_parameters_override List[Dict[String, Any]] (省略可)ハイパーパラメータをチューニングするためのオーバーライド。このパラメータは、空にすることも、1 つ以上のハイパーパラメータを含めることもできます。ハイパーパラメータ値が設定されていない場合、Vertex AI はハイパーパラメータにデフォルトのチューニング範囲を使用します。

study_spec_parameters_override パラメータを使用してハイパーパラメータを構成する場合は、Vertex AI のヘルパー関数 get_tabnet_study_spec_parameters_override を使用できます。この関数には以下の入力があります。

  • dataset_size_bucket: データセット サイズのバケット
    • 「small」: 100 万行未満
    • 「medium」: 100 万~1 億行
    • 「large」: 1 億行超
  • training_budget_bucket: トレーニング予算のバケット
    • 「small」: $600 未満
    • 「medium」: $600~$2,400
    • 「large」: $2,400 超
  • prediction_type: 目的の予測タイプ

get_tabnet_study_spec_parameters_override 関数は、ハイパーパラメータと範囲のリストを返します。

get_tabnet_study_spec_parameters_override 関数の使用方法の例を次に示します。

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

CustomJob を使用してモデルをトレーニングする

次のサンプルコードは、CustomJob パイプラインを実行する方法を示しています。

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

pipeline_job.run() でオプションの service_account パラメータを使用すると、Vertex AI Pipelines サービス アカウントを任意のアカウントに設定できます。

パイプラインとパラメータ値は、次の関数で定義されます。トレーニング データは、Cloud Storage の CSV ファイルか、BigQuery のテーブルのいずれかです。

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

get_tabnet_trainer_pipeline_and_parameters パラメータのサブセットは次のとおりです。

パラメータ名 定義
data_source_csv_filenames 文字列 Cloud Storage に保存されている CSV の URI。
data_source_bigquery_table_path 文字列 BigQuery テーブルの URI。
dataflow_service_account 文字列 (省略可)Dataflow ジョブを実行するカスタム サービス アカウント。プライベート IP と特定の VPC サブネットを使用するように Dataflow ジョブを構成できます。このパラメータは、デフォルトの Dataflow ワーカー サービス アカウントのオーバーライドとして機能します。

次のステップ

分類モデルまたは回帰モデルで予測を行う準備ができたら、次の 2 つのオプションがあります。