使用 TabNet 训练模型

本页面介绍了如何使用 TabNet 的表格工作流训练表格数据集中的分类或回归模型。

TabNet 的表格工作流有两个版本:

  • HyperparameterTuningJob 可搜索用于模型训练的最佳超参数值集。
  • CustomJob 可让您指定用于模型训练的超参数值。如果您确切地知道需要哪些超参数值,则可以指定它们(而不是进行搜索),从而节省训练资源。

如需了解此工作流使用的服务账号,请参阅表格工作流的服务账号

Workflow API

此工作流使用以下 API:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

使用 HyperparameterTuningJob 训练模型

以下示例代码演示了如何运行 HyperparameterTuningJob 流水线:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

您可以使用 pipeline_job.run() 中的可选 service_account 参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。

流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

以下是部分 get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters 参数:

参数名称 类型 定义
data_source_csv_filenames 字符串 存储在 Cloud Storage 中的 CSV 的 URI。
data_source_bigquery_table_path 字符串 BigQuery 表的 URI。
dataflow_service_account 字符串 (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。
study_spec_parameters_override List[Dict[String, Any]] (可选)调节超参数的替换值。此参数可以为空,也可以包含一个或多个可能的超参数。如果未设置超参数值,Vertex AI 会使用超参数的默认调节范围。

如果您要使用 study_spec_parameters_override 参数配置超参数,则可以使用 Vertex AI 的辅助函数 get_tabnet_study_spec_parameters_override。该函数具有以下输入:

  • dataset_size_bucket:数据集大小的存储桶
    • “small”:< 100 万行
    • “medium”:100 万 - 1 亿行
    • “large”:> 1 亿行
  • training_budget_bucket:训练预算的存储桶
    • “small”:< $600
    • “medium”:$600 - $2400
    • “large”:> $2400
  • prediction_type:所需的预测类型

get_tabnet_study_spec_parameters_override 函数会返回超参数和范围的列表。

以下示例展示了如何使用 get_tabnet_study_spec_parameters_override 函数:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

使用 CustomJob 训练模型

以下示例代码演示了如何运行 CustomJob 流水线:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

您可以使用 pipeline_job.run() 中的可选 service_account 参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。

流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

以下是部分 get_tabnet_trainer_pipeline_and_parameters 参数:

参数名称 类型 定义
data_source_csv_filenames 字符串 存储在 Cloud Storage 中的 CSV 的 URI。
data_source_bigquery_table_path 字符串 BigQuery 表的 URI。
dataflow_service_account 字符串 (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。

后续步骤

准备好通过分类或回归模型进行预测后,您可以采用以下两种方法开始预测: