本页面介绍了如何使用 Wide & Deep 的表格工作流从表格数据集训练分类或回归模型。
Wide & Deep 的表格工作流有两个版本:
- HyperparameterTuneJob 可搜索用于模型训练的最佳超参数值集。
- CustomJob 可让您指定用于模型训练的超参数值。如果您确切地知道需要哪些超参数值,则可以指定它们(而不是进行搜索),从而节省训练资源。
如需了解此工作流使用的服务账号,请参阅表格工作流的服务账号。
Workflow API
此工作流使用以下 API:
- Vertex AI
- Dataflow
- Compute Engine
- Cloud Storage
使用 HyperparameterTuningJob 训练模型
以下示例代码演示了如何运行 HyperparameterTuningJob 流水线:
pipeline_job = aiplatform.PipelineJob(
...
template_path=template_path,
parameter_values=parameter_values,
...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
您可以使用 pipeline_job.run()
中的可选 service_account
参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。
流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。
template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(...)
以下是部分 get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters
参数:
参数名称 | 类型 | 定义 |
---|---|---|
data_source_csv_filenames |
字符串 | 存储在 Cloud Storage 中的 CSV 的 URI。 |
data_source_bigquery_table_path |
字符串 | BigQuery 表的 URI。 |
dataflow_service_account |
字符串 | (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。 |
study_spec_parameters_override |
List[Dict[String, Any]] | (可选)调节超参数的替换值。此参数可以为空,也可以包含一个或多个可能的超参数。如果未设置超参数值,Vertex AI 会使用超参数的默认调节范围。 |
如果您要使用 study_spec_parameters_override
参数配置超参数,则可以使用 Vertex AI 的辅助函数 get_wide_and_deep_study_spec_parameters_override
。此函数会返回超参数和范围的列表。
以下示例展示了如何使用 get_wide_and_deep_study_spec_parameters_override
函数:
study_spec_parameters_override = automl_tabular_utils.get_wide_and_deep_study_spec_parameters_override()
使用 CustomJob 训练模型
以下示例代码演示了如何运行 CustomJob 流水线:
pipeline_job = aiplatform.PipelineJob(
...
template_path=template_path,
parameter_values=parameter_values,
...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
您可以使用 pipeline_job.run()
中的可选 service_account
参数,将 Vertex AI Pipelines 服务账号设置为您选择的账号。
流水线和参数值由以下函数定义。训练数据可以是 Cloud Storage 中的 CSV 文件,也可以是 BigQuery 中的表。
template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_trainer_pipeline_and_parameters(...)
以下是部分 get_wide_and_deep_trainer_pipeline_and_parameters
参数:
参数名称 | 类型 | 定义 |
---|---|---|
data_source_csv_filenames |
字符串 | 存储在 Cloud Storage 中的 CSV 的 URI。 |
data_source_bigquery_table_path |
字符串 | BigQuery 表的 URI。 |
dataflow_service_account |
字符串 | (可选)用于运行 Dataflow 作业的自定义服务账号。Dataflow 作业可以配置为使用专用 IP 和特定 VPC 子网。 此参数充当默认 Dataflow 工作器服务账号的替换值。 |
后续步骤
准备好通过分类或回归模型进行预测后,您可以采用以下两种方法开始预测: