Modell mit TabNet trainieren

Auf dieser Seite erfahren Sie, wie Sie mit dem tabellarischen Workflow für TabNet ein Klassifizierungs- oder Regressionsmodell aus einem tabellarischen Dataset trainieren.

Es gibt zwei Versionen des tabellarischen Workflows für TabNet:

  • HyperparameterTuningJob sucht nach den besten Hyperparameter-Werten für das Modelltraining.
  • Mit CustomJob können Sie die Hyperparameter-Werte angeben, die für das Modelltraining verwendet werden sollen. Wenn Sie genau wissen, welche Hyperparameter-Werte Sie benötigen, können Sie diese angeben, anstatt nach ihnen zu suchen. Sie sparen damit Trainingsressourcen.

Weitere Informationen zu den von diesem Workflow verwendeten Dienstkonten finden Sie unter Dienstkonten für tabellarische Workflows.

Workflow-APIs

Dieser Workflow verwendet folgende APIs:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

Modell mit HyperparameterTuningJob trainieren

Der folgende Beispielcode zeigt das Ausführen einer HyperparameterTuningJob-Pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

Mit dem optionalen Parameter service_account in pipeline_job.run() können Sie das Vertex AI Pipelines-Dienstkonto auf ein Konto Ihrer Wahl festlegen.

Die Pipeline und die Parameterwerte werden durch die folgende Funktion definiert. Die Trainingsdaten können entweder eine CSV-Datei in Cloud Storage oder eine Tabelle in BigQuery sein.

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

Im Folgenden finden Sie einen Teil der get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters-Parameter:

Parametername Typ Definition
data_source_csv_filenames String Ein URI für eine in Cloud Storage gespeicherte CSV-Datei.
data_source_bigquery_table_path String Ein URI für eine BigQuery-Tabelle.
dataflow_service_account String (Optional) Benutzerdefiniertes Dienstkonto zum Ausführen von Dataflow-Jobs. Der Dataflow-Job kann so konfiguriert werden, dass private IP-Adressen und ein bestimmtes VPC-Subnetz verwendet werden. Dieser Parameter dient als Überschreibung für das Standarddienstkonto des Dataflow-Workers.
study_spec_parameters_override List[Dict[String, Any]] (Optional) Eine Überschreibung zum Optimieren von Hyperparametern. Dieser Parameter kann leer sein oder einen oder mehrere mögliche Hyperparameter enthalten. Wenn kein Hyperparameterwert festgelegt ist, verwendet Vertex AI den Standardabstimmungsbereich für den Hyperparameter.

Wenn Sie die Hyperparameter mit dem Parameter study_spec_parameters_override konfigurieren möchten, können Sie die Hilfsfunktion get_tabnet_study_spec_parameters_override von Vertex AI verwenden. Die Funktion hat folgende Eingaben:

  • dataset_size_bucket: Einen Bucket für die Dataset-Größe
    • „klein“: < 1 Million Zeilen
    • „mittel“: 1 Mio. - 100 Mio. Zeilen
    • „groß“: > 100 Millionen Zeilen
  • training_budget_bucket: Einen Bucket für das Trainingsbudget
    • „klein“: < 600 $
    • „mittel“: 600 $ - 2400 $
    • „groß“: > 2.400 $
  • prediction_type: Der gewünschte Vorhersagetyp

get_tabnet_study_spec_parameters_override gibt eine Liste von Hyperparametern und Bereichen zurück.

Im Folgenden finden Sie ein Beispiel für die Verwendung der Funktion get_tabnet_study_spec_parameters_override:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

Modell mit CustomJob trainieren

Der folgende Beispielcode zeigt das Ausführen einer CustomJob-Pipeline:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

Mit dem optionalen Parameter service_account in pipeline_job.run() können Sie das Vertex AI Pipelines-Dienstkonto auf ein Konto Ihrer Wahl festlegen.

Die Pipeline und die Parameterwerte werden durch die folgende Funktion definiert. Die Trainingsdaten können entweder eine CSV-Datei in Cloud Storage oder eine Tabelle in BigQuery sein.

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

Im Folgenden finden Sie einen Teil der get_tabnet_trainer_pipeline_and_parameters-Parameter:

Parametername Typ Definition
data_source_csv_filenames String Ein URI für eine in Cloud Storage gespeicherte CSV-Datei.
data_source_bigquery_table_path String Ein URI für eine BigQuery-Tabelle.
dataflow_service_account String (Optional) Benutzerdefiniertes Dienstkonto zum Ausführen von Dataflow-Jobs. Der Dataflow-Job kann so konfiguriert werden, dass private IP-Adressen und ein bestimmtes VPC-Subnetz verwendet werden. Dieser Parameter dient als Überschreibung für das Standarddienstkonto des Dataflow-Workers.

Nächste Schritte

Sobald Sie bereit sind, Vorhersagen mit Ihrem Klassifizierungs- oder Regressionsmodell zu treffen, haben Sie zwei Möglichkeiten: