Creazione di una pipeline di addestramento per la regressione tabulare

Crea una pipeline di addestramento per la regressione tabulare utilizzando il metodo create_training_pipeline.

Per saperne di più

Per la documentazione dettagliata che include questo esempio di codice, consulta quanto segue:

Esempio di codice

Java

Prima di provare questo esempio, segui le istruzioni di configurazione Java riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Java di Vertex AI.

Per autenticarti in Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableRegression(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      ArrayList<Transformation> tranformations = new ArrayList<>();
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("TIMESTAMP_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("DATETIME_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
              .build());

      AutoMlTablesInputs trainingTaskInputs =
          AutoMlTablesInputs.newBuilder()
              .addAllTransformations(tranformations)
              .setTargetColumn(targetColumn)
              .setPredictionType("regression")
              .setTrainBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              // supported regression optimisation objectives: minimize-rmse,
              // minimize-mae, minimize-rmsle
              .setOptimizationObjective("minimize-rmse")
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Regression Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Prima di provare questo esempio, segui le istruzioni di configurazione Node.js riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Node.js di Vertex AI.

Per autenticarti a Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesRegression() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'BOOLEAN_2unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'TIMESTAMP_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'DATE_1unique_NULLABLE'}},
    {auto: {column_name: 'TIME_1unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'DATETIME_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'STRUCT_NULLABLE.STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE'}},
  ];

  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    transformations,
    targetColumn,
    predictionType: 'regression',
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-rmse',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular regression response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesRegression();

Python

Prima di provare questo esempio, segui le istruzioni di configurazione Python riportate nella guida rapida all'utilizzo delle librerie client di Vertex AI. Per ulteriori informazioni, consulta la documentazione di riferimento dell'API Python di Vertex AI.

Per autenticarti in Vertex AI, configura le Credenziali predefinite dell'applicazione. Per ulteriori informazioni, consulta Configurare l'autenticazione per un ambiente di sviluppo locale.

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value


def create_training_pipeline_tabular_regression_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    target_column: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    # set the columns used for training and their data types
    transformations = [
        {"auto": {"column_name": "STRING_5000unique_NULLABLE"}},
        {"auto": {"column_name": "INTEGER_5000unique_NULLABLE"}},
        {"auto": {"column_name": "FLOAT_5000unique_NULLABLE"}},
        {"auto": {"column_name": "FLOAT_5000unique_REPEATED"}},
        {"auto": {"column_name": "NUMERIC_5000unique_NULLABLE"}},
        {"auto": {"column_name": "BOOLEAN_2unique_NULLABLE"}},
        {
            "timestamp": {
                "column_name": "TIMESTAMP_1unique_NULLABLE",
                "invalid_values_allowed": True,
            }
        },
        {"auto": {"column_name": "DATE_1unique_NULLABLE"}},
        {"auto": {"column_name": "TIME_1unique_NULLABLE"}},
        {
            "timestamp": {
                "column_name": "DATETIME_1unique_NULLABLE",
                "invalid_values_allowed": True,
            }
        },
        {"auto": {"column_name": "STRUCT_NULLABLE.STRING_5000unique_NULLABLE"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE"}},
        {"auto": {"column_name": "STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"}},
    ]

    training_task_inputs_dict = {
        # required inputs
        "targetColumn": target_column,
        "predictionType": "regression",
        "transformations": transformations,
        "trainBudgetMilliNodeHours": 8000,
        # optional inputs
        "disableEarlyStopping": False,
        # supported regression optimisation objectives: minimize-rmse,
        # minimize-mae, minimize-rmsle
        "optimizationObjective": "minimize-rmse",
    }
    training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {
            "dataset_id": dataset_id,
            "fraction_split": {
                "training_fraction": 0.8,
                "validation_fraction": 0.1,
                "test_fraction": 0.1,
            },
        },
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

Passaggi successivi

Per cercare e filtrare gli esempi di codice per altri prodotti Google Cloud, consulta il browser di esempi di Google Cloud.