Train a model with TabNet

This page shows you how to train a classification or regression model from a tabular dataset with the Tabular Workflow for TabNet.

Two versions of the Tabular Workflow for TabNet are available:

  • HyperparameterTuningJob searches for the best set of hyperparameter values to use for model training.
  • CustomJob lets you specify the hyperparameter values to use for model training. If you know exactly which hyperparameter values you need, you can specify them instead of searching for them and save on training resources.

To learn about the service accounts used by this workflow, see Service accounts for Tabular Workflows.

Workflow APIs

This workflow uses the following APIs:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

Train a model with HyperparameterTuningJob

The following sample code demonstrates how you can run a HyperparameterTuningJob pipeline:

from google.cloud import aiplatform

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
# Run the pipeline, optionally overriding the Vertex AI Pipelines service account.
pipeline_job.run(service_account=SERVICE_ACCOUNT)

The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.
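For example, SERVICE_ACCOUNT can be set to the email address of a service account in your project. The values below are placeholders, not values defined by this workflow:

# Placeholder values; substitute your own project ID and service account name.
PROJECT_ID = "my-project"
SERVICE_ACCOUNT = f"my-pipelines-sa@{PROJECT_ID}.iam.gserviceaccount.com"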

The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.

template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

The following is a subset of get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters parameters:

  • data_source_csv_filenames (String): A URI for a CSV stored in Cloud Storage.
  • data_source_bigquery_table_path (String): A URI for a BigQuery table.
  • dataflow_service_account (String, optional): Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
  • study_spec_parameters_override (List[Dict[String, Any]], optional): An override for tuning hyperparameters. This parameter can be empty or contain one or more of the possible hyperparameters. If a hyperparameter value is not set, Vertex AI uses the default tuning range for the hyperparameter.
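For example, a call to the function might look like the following. This is a minimal sketch: aside from data_source_csv_filenames, the argument names shown (project, location, root_dir, target_column, prediction_type) and their values are assumptions for illustration, so check the function signature in your SDK version.

# A minimal sketch. Arguments other than data_source_csv_filenames are assumed
# for illustration and may differ in your SDK version.
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    project="my-project",                      # assumed argument
    location="us-central1",                    # assumed argument
    root_dir="gs://my-bucket/pipeline_root",   # assumed argument
    target_column="label",                     # assumed argument
    prediction_type="classification",          # assumed argument
    data_source_csv_filenames="gs://my-bucket/train.csv",
)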

If you want to configure the hyperparameters using the study_spec_parameters_override parameter, you can use Vertex AI's helper function get_tabnet_study_spec_parameters_override. The function has the following inputs:

  • dataset_size_bucket: A bucket for the dataset size
    • 'small': < 1M rows
    • 'medium': 1M - 100M rows
    • 'large': > 100M rows
  • training_budget_bucket: A bucket for the training budget
    • 'small': < $600
    • 'medium': $600 - $2400
    • 'large': > $2400
  • prediction_type: The desired prediction type

The get_tabnet_study_spec_parameters_override function returns a list of hyperparameters and ranges.

The following is an example of how the get_tabnet_study_spec_parameters_override function can be used:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)
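You can then pass the returned list to the pipeline through the study_spec_parameters_override parameter, for example:

# Pass the generated hyperparameter ranges as study_spec_parameters_override
# (other arguments elided; see the earlier sample).
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    ...
    study_spec_parameters_override=study_spec_parameters_override,
    ...
)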

Train a model with CustomJob

The following sample code demonstrates how you can run a CustomJob pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.

The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

The following is a subset of get_tabnet_trainer_pipeline_and_parameters parameters:

  • data_source_csv_filenames (String): A URI for a CSV stored in Cloud Storage.
  • data_source_bigquery_table_path (String): A URI for a BigQuery table.
  • dataflow_service_account (String, optional): Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
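Because CustomJob expects you to provide the hyperparameter values yourself, a call might look like the following sketch. Aside from data_source_csv_filenames, the argument names shown (including the learning_rate and max_steps hyperparameters) are assumptions for illustration and may differ in your SDK version.

# A minimal sketch. Arguments other than data_source_csv_filenames are assumed
# for illustration and may differ in your SDK version.
template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(
    project="my-project",                      # assumed argument
    location="us-central1",                    # assumed argument
    root_dir="gs://my-bucket/pipeline_root",   # assumed argument
    target_column="label",                     # assumed argument
    prediction_type="regression",              # assumed argument
    data_source_csv_filenames="gs://my-bucket/train.csv",
    learning_rate=0.01,                        # assumed hyperparameter argument
    max_steps=2000,                            # assumed hyperparameter argument
)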

What's next

Once you're ready to make predictions with your classification or regression model, you have two options: