本页面介绍了如何使用 Wide & Deep 的表格工作流从表格数据集训练分类或回归模型。
Wide & Deep 的表格工作流有两个版本:
- HyperparameterTuningJob 可搜索用于模型训练的最佳超参数值集。
 - CustomJob 可让您指定用于模型训练的超参数值。如果您确切地知道需要哪些超参数值,请指定它们(而不是进行搜索),从而节省训练资源。
 
如需了解此工作流使用的服务账号,请参阅 Tabular Workflows 的服务账号。
Workflow API
此工作流使用以下 API:
- Vertex AI
 - Dataflow
 - Compute Engine
 - Cloud Storage
 
使用 HyperparameterTuningJob 训练模型
以下示例代码演示了如何运行 HyperparameterTuningJob 流水线:
pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
您可以使用 pipeline_job.run() 中的可选 service_account 参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。
流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。
template_path, parameter_values =  automl_tabular_utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(...)
以下是部分 get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters 参数:
| 参数名称 | 类型 | 定义 | 
|---|---|---|
data_source_csv_filenames | 
字符串 | 存储在 Cloud Storage 中的 CSV 的 URI。 | 
data_source_bigquery_table_path | 
字符串 | BigQuery 表的 URI。 | 
dataflow_service_account | 
字符串 | (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。 | 
study_spec_parameters_override | 
List[Dict[String, Any]] | (可选)调节超参数的替换值。此参数可以为空,也可以包含一个或多个可能的超参数。如果未设置超参数值,Vertex AI 会使用超参数的默认调节范围。 | 
如需使用 study_spec_parameters_override 参数配置超参数,请使用 Vertex AI 的辅助函数 get_wide_and_deep_study_spec_parameters_override。此函数会返回超参数和范围的列表。
以下示例展示了如何使用 get_wide_and_deep_study_spec_parameters_override 函数:
study_spec_parameters_override = automl_tabular_utils.get_wide_and_deep_study_spec_parameters_override()
使用 CustomJob 训练模型
以下示例代码演示了如何运行 CustomJob 流水线:
pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
您可以使用 pipeline_job.run() 中的可选 service_account 参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。
流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。
template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_trainer_pipeline_and_parameters(...)
以下是部分 get_wide_and_deep_trainer_pipeline_and_parameters 参数:
| 参数名称 | 类型 | 定义 | 
|---|---|---|
data_source_csv_filenames | 
字符串 | 存储在 Cloud Storage 中的 CSV 的 URI。 | 
data_source_bigquery_table_path | 
字符串 | BigQuery 表的 URI。 | 
dataflow_service_account | 
字符串 | (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。 | 
后续步骤
准备好通过分类或回归模型进行推理后,您可以采用以下两种方法开始推理: