This page shows you how to train a classification or regression model from a tabular dataset with the Tabular Workflow for TabNet.
Two versions of the Tabular Workflow for TabNet are available:
- HyperparameterTuningJob searches for the best set of hyperparameter values to use for model training.
- CustomJob lets you specify the hyperparameter values to use for model training. If you know exactly which hyperparameter values you need, you can specify them instead of searching for them and save on training resources.
To learn about the service accounts used by this workflow, see Service accounts for Tabular Workflows.
Workflow APIs
This workflow uses the following APIs:
- Vertex AI
- Dataflow
- Compute Engine
- Cloud Storage
Train a model with HyperparameterTuningJob
The following sample code demonstrates how you can run a HyperparameterTuningJob pipeline:
pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.
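For orientation, the following is a minimal sketch of this step. The project, region, pipeline root bucket, display name, and service account values are placeholders to replace with your own; template_path and parameter_values come from the helper function described next.

from google.cloud import aiplatform

# Placeholder values -- replace these with your own project, region, bucket, and service account.
PROJECT_ID = "your-project-id"
REGION = "us-central1"
PIPELINE_ROOT = "gs://your-bucket/pipeline_root"
SERVICE_ACCOUNT = "your-service-account@your-project-id.iam.gserviceaccount.com"

aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=PIPELINE_ROOT)

# template_path and parameter_values are produced by the helper function described below.
pipeline_job = aiplatform.PipelineJob(
    display_name="tabnet-hyperparameter-tuning",
    template_path=template_path,
    parameter_values=parameter_values,
    pipeline_root=PIPELINE_ROOT,
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)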
The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)
The following is a subset of get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters parameters:

Parameter name | Type | Definition
---|---|---
data_source_csv_filenames | String | A URI for a CSV stored in Cloud Storage.
data_source_bigquery_table_path | String | A URI for a BigQuery table.
dataflow_service_account | String | (Optional) Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
study_spec_parameters_override | List[Dict[String, Any]] | (Optional) An override for tuning hyperparameters. This parameter can be empty or contain one or more of the possible hyperparameters. If a hyperparameter value is not set, Vertex AI uses the default tuning range for the hyperparameter.
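For example, a call that trains a classification model from a CSV file in Cloud Storage might look like the following sketch. Only data_source_csv_filenames comes from the table above; the remaining argument names (project, location, root_dir, target_column, prediction_type) are assumptions about the helper's signature and may differ in your SDK version.

# A sketch only: argument names other than data_source_csv_filenames are assumptions
# and may differ in your version of automl_tabular_utils.
template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    project=PROJECT_ID,                # your Google Cloud project ID
    location=REGION,                   # region where the pipeline runs
    root_dir=PIPELINE_ROOT,            # Cloud Storage directory for pipeline artifacts
    target_column="target",            # name of the label column in the training data
    prediction_type="classification",  # or "regression"
    data_source_csv_filenames="gs://your-bucket/data/train.csv",
)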
If you want to configure the hyperparameters using the study_spec_parameters_override parameter, you can use Vertex AI's helper function get_tabnet_study_spec_parameters_override.
The function has the following inputs:
- dataset_size_bucket: A bucket for the dataset size
  - 'small': < 1M rows
  - 'medium': 1M - 100M rows
  - 'large': > 100M rows
- training_budget_bucket: A bucket for the training budget
  - 'small': < $600
  - 'medium': $600 - $2400
  - 'large': > $2400
- prediction_type: The desired prediction type, such as 'classification' or 'regression'
The get_tabnet_study_spec_parameters_override function returns a list of hyperparameters and ranges.
The following is an example of how the get_tabnet_study_spec_parameters_override function can be used:
study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)
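The returned list can then be passed to the pipeline helper through the study_spec_parameters_override parameter described in the table above. In the sketch below, the ellipsis stands for the helper's other arguments:

template_path, parameter_values = automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(
    ...,  # the helper's other arguments, as described above
    study_spec_parameters_override=study_spec_parameters_override,
)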
Train a model with CustomJob
The following sample code demonstrates how you can run a CustomJob pipeline:
pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)
The optional service_account parameter in pipeline_job.run() lets you set the Vertex AI Pipelines service account to an account of your choice.
The pipeline and the parameter values are defined by the following function. The training data can be either a CSV file in Cloud Storage or a table in BigQuery.
template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)
The following is a subset of get_tabnet_trainer_pipeline_and_parameters parameters:

Parameter name | Type | Definition
---|---|---
data_source_csv_filenames | String | A URI for a CSV stored in Cloud Storage.
data_source_bigquery_table_path | String | A URI for a BigQuery table.
dataflow_service_account | String | (Optional) Custom service account to run Dataflow jobs. The Dataflow job can be configured to use private IPs and a specific VPC subnet. This parameter acts as an override for the default Dataflow worker service account.
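For example, a trainer call that reads from a BigQuery table and fixes the hyperparameter values directly might look like the following sketch. Only data_source_bigquery_table_path comes from the table above; the remaining argument names (project, location, root_dir, target_column, prediction_type, learning_rate, max_steps) are assumptions about the helper's signature and may differ in your SDK version.

# A sketch only: argument names other than data_source_bigquery_table_path are
# assumptions and may differ in your version of automl_tabular_utils.
template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(
    project=PROJECT_ID,
    location=REGION,
    root_dir=PIPELINE_ROOT,
    target_column="target",
    prediction_type="regression",
    data_source_bigquery_table_path="bq://your-project.your_dataset.training_data",
    learning_rate=0.01,  # hyperparameter values you choose directly, instead of searching for them
    max_steps=2000,
)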
What's next
Once you're ready to make predictions with your classification or regression model, you have two options:
- Make online (real-time) predictions using your model.
- Get batch predictions directly from your model.
You can also learn about pricing for model training.