Train a model with Wide & Deep

This page shows you how to train a classification or regression model from a tabular dataset with the Tabular Workflow for Wide & Deep.

Two versions of the Tabular Workflow for Wide & Deep are available:

  • HyperparameterTuningJob searches for the best set of hyperparameter values to use for model training.
  • CustomJob lets you specify the hyperparameter values to use for model training. If you know exactly which hyperparameter values you need, you can specify them instead of searching for them and save on training resources.

To learn about the service accounts used by this workflow, see Service accounts for Tabular Workflows.

Workflow APIs

This workflow uses the following APIs:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

Train a model with HyperparameterTuningJob

The following sample code demonstrates how you can run a HyperparameterTuningJob pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.
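
For reference, a more complete invocation might look like the following. This is a minimal sketch: the display name, project, region, and pipeline_root bucket are placeholder values, and template_path, parameter_values, and SERVICE_ACCOUNT come from the surrounding steps.

from google.cloud import aiplatform

# Placeholder project and region; replace with your own values.
aiplatform.init(project="my-project", location="us-central1")

pipeline_job = aiplatform.PipelineJob(
    display_name="wide-and-deep-training",         # placeholder display name
    template_path=template_path,                   # returned by the helper function below
    parameter_values=parameter_values,             # returned by the helper function below
    pipeline_root="gs://my-bucket/pipeline_root",  # placeholder staging location
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)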

The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.

template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(...)

The following is a subset of get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters parameters:

  • data_source_csv_filenames (String): A URI for a CSV stored in Cloud Storage.
  • data_source_bigquery_table_path (String): A URI for a BigQuery table.
  • dataflow_service_account (String, optional): Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
  • study_spec_parameters_override (List[Dict[String, Any]], optional): An override for tuning hyperparameters. This parameter can be empty, or it can contain one or more of the possible hyperparameters. If a hyperparameter value is not set, Vertex AI uses the default tuning range for the hyperparameter.
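
As an illustration, a call to the helper might look like the following sketch. Only data_source_csv_filenames and study_spec_parameters_override are documented on this page; the remaining keyword arguments are assumptions about typical usage, so check the function reference for the exact signature.

# Minimal sketch; arguments marked "assumed" are not documented on this page.
template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(
    project="my-project",                      # assumed argument
    location="us-central1",                    # assumed argument
    root_dir="gs://my-bucket/pipeline_root",   # assumed argument
    target_column="label",                     # assumed argument
    prediction_type="classification",          # assumed argument
    data_source_csv_filenames="gs://my-bucket/data/train.csv",
    study_spec_parameters_override=study_spec_parameters_override,  # see the next section
)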

If you want to configure the hyperparameters using the study_spec_parameters_override parameter, you can use Vertex AI's helper function get_wide_and_deep_study_spec_parameters_override. This function returns a list of hyperparameters and ranges.

The following is an example of how the get_wide_and_deep_study_spec_parameters_override function can be used:

study_spec_parameters_override = automl_tabular_utils.get_wide_and_deep_study_spec_parameters_override()
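
Because the exact hyperparameter names and ranges can vary by SDK version, it can help to inspect the returned list before passing it to the pipeline parameters function. A short sketch, assuming each entry follows the Vertex AI study-spec format with a parameter_id key:

# Print each hyperparameter spec to see its name and tuning range before
# deciding which entries, if any, to edit.
for spec in study_spec_parameters_override:
    print(spec["parameter_id"], spec)

You can then pass the (possibly edited) list through the study_spec_parameters_override parameter described above.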

Train a model with CustomJob

The following sample code demonstrates how you can run a CustomJob pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.

The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.

template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_trainer_pipeline_and_parameters(...)

The following is a subset of get_wide_and_deep_trainer_pipeline_and_parameters parameters:

  • data_source_csv_filenames (String): A URI for a CSV stored in Cloud Storage.
  • data_source_bigquery_table_path (String): A URI for a BigQuery table.
  • dataflow_service_account (String, optional): Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
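
As an illustration, a call to the trainer helper might look like the following sketch. Only the three parameters documented above appear on this page; the remaining keyword arguments, including the explicit learning rates that replace the hyperparameter search, are assumptions about typical usage, so check the function reference for the exact signature.

# Minimal sketch; arguments marked "assumed" are not documented on this page.
template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_trainer_pipeline_and_parameters(
    project="my-project",                      # assumed argument
    location="us-central1",                    # assumed argument
    root_dir="gs://my-bucket/pipeline_root",   # assumed argument
    target_column="label",                     # assumed argument
    prediction_type="classification",          # assumed argument
    data_source_bigquery_table_path="bq://my-project.my_dataset.training_data",
    learning_rate=0.01,                        # assumed argument
    dnn_learning_rate=0.001,                   # assumed argument
)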

What's next

Once you're ready to make predictions with your classification or regression model, you have two options: