本页面介绍了如何使用 TabNet 的表格工作流训练表格数据集中的分类或回归模型。
TabNet 的表格工作流有两个版本:
- HyperparameterTuningJob 可搜索用于模型训练的最佳超参数值集。
- CustomJob 可让您指定用于模型训练的超参数值。如果您确切地知道需要哪些超参数值,则可以指定它们(而不是进行搜索),从而节省训练资源。
如需了解此工作流使用的服务账号,请参阅表格工作流的服务账号。
Workflow API
此工作流使用以下 API:
- Vertex AI
- Dataflow
- Compute Engine
- Cloud Storage
使用 HyperparameterTuningJob 训练模型
以下示例代码演示了如何运行 HyperparameterTuningJob 流水线:
pipeline_job = aiplatform.PipelineJob(
...
template_path=template_path,
parameter_values=parameter_values,
...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
您可以使用 pipeline_job.run()
中的可选 service_account
参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。
流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)
以下是部分 get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters
参数:
参数名称 | 类型 | 定义 |
---|---|---|
data_source_csv_filenames |
字符串 | 存储在 Cloud Storage 中的 CSV 的 URI。 |
data_source_bigquery_table_path |
字符串 | BigQuery 表的 URI。 |
dataflow_service_account |
字符串 | (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。 |
study_spec_parameters_override |
List[Dict[String, Any]] | (可选)调节超参数的替换值。此参数可以为空,也可以包含一个或多个可能的超参数。如果未设置超参数值,Vertex AI 会使用超参数的默认调节范围。 |
如果您要使用 study_spec_parameters_override
参数配置超参数,则可以使用 Vertex AI 的辅助函数 get_tabnet_study_spec_parameters_override
。该函数具有以下输入:
dataset_size_bucket
:数据集大小的存储桶- “small”:< 100 万行
- “medium”:100 万 - 1 亿行
- “large”:> 1 亿行
training_budget_bucket
:训练预算的存储桶- “small”:< $600
- “medium”:$600 - $2400
- “large”:> $2400
prediction_type
:所需的预测类型
get_tabnet_study_spec_parameters_override
函数会返回超参数和范围的列表。
以下示例展示了如何使用 get_tabnet_study_spec_parameters_override
函数:
study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
dataset_size_bucket="small",
prediction_type="classification",
training_budget_bucket="small",
)
使用 CustomJob 训练模型
以下示例代码演示了如何运行 CustomJob 流水线:
pipeline_job = aiplatform.PipelineJob(
...
template_path=template_path,
parameter_values=parameter_values,
...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
您可以使用 pipeline_job.run()
中的可选 service_account
参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。
流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。
template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)
以下是部分 get_tabnet_trainer_pipeline_and_parameters
参数:
参数名称 | 类型 | 定义 |
---|---|---|
data_source_csv_filenames |
字符串 | 存储在 Cloud Storage 中的 CSV 的 URI。 |
data_source_bigquery_table_path |
字符串 | BigQuery 表的 URI。 |
dataflow_service_account |
字符串 | (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。 |
后续步骤
准备好通过分类或回归模型进行预测后,您可以采用以下两种方法开始预测: