Entrenar un modelo con TabNet

En esta página se explica cómo entrenar un modelo de clasificación o regresión a partir de un conjunto de datos tabulares con el flujo de trabajo tabular de TabNet.

Hay dos versiones del flujo de trabajo tabular para TabNet:

  • HyperparameterTuningJob busca el mejor conjunto de valores de hiperparámetros que se puede usar para entrenar un modelo.
  • CustomJob te permite especificar los valores de los hiperparámetros que se usarán para entrenar el modelo. Si sabes exactamente qué valores de hiperparámetros necesitas, especifícalos en lugar de buscarlos y ahorra recursos de entrenamiento.

Para obtener información sobre las cuentas de servicio que usa este flujo de trabajo, consulta Cuentas de servicio para flujos de trabajo tabulares.

APIs de flujo de trabajo

Este flujo de trabajo usa las siguientes APIs:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

Entrenar un modelo con HyperparameterTuningJob

En el siguiente código de ejemplo se muestra cómo ejecutar una canalización HyperparameterTuningJob:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

El parámetro opcional service_account de pipeline_job.run() te permite definir la cuenta de servicio de Vertex AI Pipelines como la que elijas.

La siguiente función define la canalización y los valores de los parámetros. Los datos de entrenamiento pueden ser un archivo CSV en Cloud Storage o una tabla en BigQuery.

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

A continuación, se muestra un subconjunto de parámetros de get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters:

Nombre del parámetro Tipo Definición
data_source_csv_filenames Cadena Un URI de un archivo CSV almacenado en Cloud Storage.
data_source_bigquery_table_path Cadena URI de una tabla de BigQuery.
dataflow_service_account Cadena (Opcional) Cuenta de servicio personalizada para ejecutar trabajos de Dataflow. El trabajo de Dataflow se puede configurar para que use IPs privadas y una subred de VPC específica. Este parámetro anula la cuenta de servicio de trabajador de Dataflow predeterminada.
study_spec_parameters_override List[Dict[String, Any]] (Opcional) Una anulación para ajustar los hiperparámetros. Este parámetro puede estar vacío o contener uno o varios de los hiperparámetros posibles. Si no se define un valor de hiperparámetro, Vertex AI usa el intervalo de ajuste predeterminado del hiperparámetro.

Para configurar los hiperparámetros con el parámetro study_spec_parameters_override, usa la función auxiliar get_tabnet_study_spec_parameters_override de Vertex AI. La función tiene las siguientes entradas:

  • dataset_size_bucket: Un contenedor para el tamaño del conjunto de datos
    • 'small': < 1 millón de filas
    • "medium": de 1 a 100 millones de filas
    • "large": > 100 millones de filas
  • training_budget_bucket: un contenedor para el presupuesto de formación
    • 'small': < 600 $
    • 'medium': 600 - 2400 $
    • 'large': > 2400 $
  • prediction_type: el tipo de inferencia que quieres

La función get_tabnet_study_spec_parameters_override devuelve una lista de hiperparámetros e intervalos.

A continuación, se muestra un ejemplo de cómo usar la función get_tabnet_study_spec_parameters_override:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

Entrenar un modelo con CustomJob

En el siguiente código de ejemplo se muestra cómo ejecutar una canalización de CustomJob:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

El parámetro opcional service_account de pipeline_job.run() te permite definir la cuenta de servicio de Vertex AI Pipelines como la que elijas.

La siguiente función define la canalización y los valores de los parámetros. Los datos de entrenamiento pueden ser un archivo CSV en Cloud Storage o una tabla en BigQuery.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

A continuación, se muestra un subconjunto de parámetros de get_tabnet_trainer_pipeline_and_parameters:

Nombre del parámetro Tipo Definición
data_source_csv_filenames Cadena Un URI de un archivo CSV almacenado en Cloud Storage.
data_source_bigquery_table_path Cadena URI de una tabla de BigQuery.
dataflow_service_account Cadena (Opcional) Cuenta de servicio personalizada para ejecutar trabajos de Dataflow. El trabajo de Dataflow se puede configurar para que use IPs privadas y una subred de VPC específica. Este parámetro anula la cuenta de servicio de trabajador de Dataflow predeterminada.

Siguientes pasos

Cuando quieras hacer inferencias con tu modelo de clasificación o regresión, tienes dos opciones: