Entrena un modelo de regresión o clasificación

En esta página, se muestra cómo entrenar un modelo de clasificación o regresión a partir de un conjunto de datos tabular mediante la consola de Goggle Cloud o la API de Vertex AI.

Antes de comenzar

Antes de entrenar un modelo, debes completar lo siguiente:

Entrenar un modelo

Consola de Google Cloud

  1. En la sección Vertex AI de la consola de Google Cloud, ve a la página Conjuntos de datos.

    Ir a la página Conjuntos de datos

  2. Haz clic en el nombre del conjunto de datos que deseas usar para entrenar tu modelo a fin de abrir su página de detalles.

  3. Si tu tipo de datos usa conjuntos de anotaciones, selecciona el conjunto de anotaciones que deseas usar para este modelo.

  4. Haga clic en Entrenar un modelo nuevo.

  5. Selecciona Otros.

  6. En la página Entrenar un modelo nuevo, completa los siguientes pasos:

    1. Selecciona el método de entrenamiento de modelos.

      • AutoML es una buena opción para una amplia gama de casos de uso.

      Haz clic en Continuar.

    2. Ingresa el nombre visible de tu modelo nuevo.

    3. Selecciona la columna objetivo.

      La columna objetivo es el valor que el modelo predecirá.

      Obtén más información sobre los requisitos de las columnas de destino.

    4. Opcional: Para exportar tu conjunto de datos de prueba a BigQuery, marca Exportar conjunto de datos de prueba a BigQuery y proporciona el nombre de la tabla.

    5. Opcional: Para elegir cómo dividir los datos entre conjuntos de entrenamiento, prueba y validación, abre las Opciones avanzadas. Puedes elegir entre las siguientes opciones de división de datos:

      • Aleatoria (predeterminada): Vertex AI selecciona de forma aleatoria las filas asociadas con cada conjunto de datos. De forma predeterminada, Vertex AI selecciona el 80% de tus filas de datos para el conjunto de entrenamiento, el 10% para el conjunto de validación y el 10% para el conjunto de prueba.
      • Manual: Vertex AI selecciona filas de datos para cada uno de los conjuntos de datos según los valores de una columna de división de datos. Proporciona el nombre de la columna de división de datos.
      • Cronológico: Vertex AI divide los datos en función de la marca de tiempo en una columna de tiempo. Proporciona el nombre de la columna de tiempo.

      Obtén más información sobre las divisiones de datos.

    6. Haz clic en Continuar.

    7. Opcional: Haz clic en Generar estadísticas. La generación de estadísticas propaga los menús desplegables de Transformación.

    8. En la página Opciones de entrenamiento, revisa tu lista de columnas y excluye del entrenamiento todas las que no se deban usar para entrenar el modelo.

    9. Revisa las transformaciones seleccionadas para los atributos incluidos, junto con la posibilidad de permitir datos no válidos y realiza las actualizaciones necesarias.

      Obtén más información sobre las transformaciones y los datos no válidos.

    10. Si deseas especificar una columna de peso o cambiar tu objetivo de optimización del valor predeterminado, abre las Opciones avanzadas y realiza tus selecciones.

      Obtén más información sobre las columnas de ponderación y los objetivos de optimización.

    11. Haz clic en Continuar.

    12. En la página Procesamiento y precios, realiza la configuración de la siguiente manera:

      Ingresa el número máximo de horas para las que deseas que se entrene el modelo.

      Esta configuración te ayuda a limitar los costos de entrenamiento. El tiempo real transcurrido puede ser más largo que este valor, ya que hay otras operaciones involucradas en la creación de un modelo nuevo.

      El tiempo de entrenamiento sugerido se relaciona con el tamaño de los datos de entrenamiento. En la siguiente tabla, se muestran los intervalos de tiempo de entrenamiento sugeridos por conteo de filas; una gran cantidad de columnas también aumentará el tiempo de entrenamiento.

      Filas Tiempo de entrenamiento sugerido
      Menor que 100,000 1-3 horas
      100,000-1,000,000 1-6 horas
      1,000,000-10,000,000 1-12 horas
      Más de 10,000,000 De 3 a 24 horas
      Para obtener información sobre los precios de entrenamiento, consulta la página de precios.

    13. Haz clic en Comenzar entrenamiento.

      El entrenamiento de modelos puede tardar muchas horas, según el tamaño y la complejidad de tus datos y tu presupuesto de entrenamiento, si especificaste uno. Puedes cerrar esta pestaña y regresar a ella más tarde. Recibirás un correo electrónico cuando tu modelo haya finalizado el entrenamiento.

API

Selecciona un objetivo de tipo de datos tabulares.

Clasificación

Selecciona una pestaña para tu idioma o entorno:

REST

Usa el comando trainingPipelines.create para entrenar un modelo.

Entrena el modelo.

Antes de usar cualquiera de los datos de solicitud a continuación, realiza los siguientes reemplazos:

  • LOCATION: Tu región.
  • PROJECT: El ID del proyecto.
  • TRAININGPIPELINE_DISPLAY_NAME: El nombre visible de la canalización de entrenamiento creada para esta operación.
  • TARGET_COLUMN: La columna (valor) que deseas que prediga este modelo.
  • WEIGHT_COLUMN: la columna de ponderación (opcional). Obtén más información.
  • TRAINING_BUDGET: la cantidad máxima de tiempo que deseas que se entrene el modelo, en milihoras de procesamiento de nodos (1,000 milihoras de procesamiento de nodos equivalen a una hora de procesamiento de nodos).
  • OPTIMIZATION_OBJECTIVE: es necesario solo si no deseas el objetivo de optimización predeterminado para tu tipo de predicción. Obtén más información.
  • TRANSFORMATION_TYPE: El tipo de transformación se proporciona para cada columna que se usa a fin de entrenar el modelo. Obtén más información.
  • COLUMN_NAME: El nombre de la columna con el tipo de transformación especificado. Se debe especificar cada columna que se usa para entrenar el modelo.
  • MODEL_DISPLAY_NAME: Nombre visible del modelo recién entrenado.
  • DATASET_ID: ID del conjunto de datos de entrenamiento.
  • Puedes proporcionar un objeto Split para controlar tu división de datos. Si deseas obtener información para controlar la división de datos, consulta Controla la división de datos mediante REST.
  • PROJECT_NUMBER: el número de proyecto de tu proyecto generado de forma automática.

Método HTTP y URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Cuerpo JSON de la solicitud:

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "classification",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Para enviar tu solicitud, expande una de estas opciones:

Deberías recibir una respuesta JSON similar a la que se muestra a continuación:

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

Antes de probar este ejemplo, sigue las instrucciones de configuración para Java incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Java.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableClassification(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      Transformation transformation1 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build())
              .build();
      Transformation transformation2 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build())
              .build();
      Transformation transformation3 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build())
              .build();
      Transformation transformation4 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build())
              .build();

      ArrayList<Transformation> transformationArrayList = new ArrayList<>();
      transformationArrayList.add(transformation1);
      transformationArrayList.add(transformation2);
      transformationArrayList.add(transformation3);
      transformationArrayList.add(transformation4);

      AutoMlTablesInputs autoMlTablesInputs =
          AutoMlTablesInputs.newBuilder()
              .setTargetColumn(targetColumn)
              .setPredictionType("classification")
              .addAllTransformations(transformationArrayList)
              .setTrainBudgetMilliNodeHours(8000)
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Antes de probar este ejemplo, sigue las instrucciones de configuración para Node.js incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Node.js.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'sepal_width'}},
    {auto: {column_name: 'sepal_length'}},
    {auto: {column_name: 'petal_length'}},
    {auto: {column_name: 'petal_width'}},
  ];
  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    targetColumn: targetColumn,
    predictionType: 'classification',
    transformations: transformations,
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-log-loss',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesClassification();

Python

Si deseas obtener información para instalar o actualizar el SDK de Vertex AI para Python, consulta Instala el SDK de Vertex AI para Python. Si deseas obtener más información, consulta la documentación de referencia de la API de Python.

def create_training_pipeline_tabular_classification_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = None,
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_classification_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="classification"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_classification_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Regresión

Selecciona una pestaña para tu idioma o entorno:

REST

Usa el comando trainingPipelines.create para entrenar un modelo.

Entrena el modelo.

Antes de usar cualquiera de los datos de solicitud a continuación, realiza los siguientes reemplazos:

  • LOCATION: Tu región.
  • PROJECT: El ID del proyecto.
  • TRAININGPIPELINE_DISPLAY_NAME: El nombre visible de la canalización de entrenamiento creada para esta operación.
  • TARGET_COLUMN: La columna (valor) que deseas que prediga este modelo.
  • WEIGHT_COLUMN: la columna de ponderación (opcional). Obtén más información.
  • TRAINING_BUDGET: la cantidad máxima de tiempo que deseas que se entrene el modelo, en milihoras de procesamiento de nodos (1,000 milihoras de procesamiento de nodos equivalen a una hora de procesamiento de nodos).
  • OPTIMIZATION_OBJECTIVE: es necesario solo si no deseas el objetivo de optimización predeterminado para tu tipo de predicción. Obtén más información.
  • TRANSFORMATION_TYPE: El tipo de transformación se proporciona para cada columna que se usa a fin de entrenar el modelo. Obtén más información.
  • COLUMN_NAME: El nombre de la columna con el tipo de transformación especificado. Se debe especificar cada columna que se usa para entrenar el modelo.
  • MODEL_DISPLAY_NAME: Nombre visible del modelo recién entrenado.
  • DATASET_ID: ID del conjunto de datos de entrenamiento.
  • Puedes proporcionar un objeto Split para controlar tu división de datos. Si deseas obtener información para controlar la división de datos, consulta Controla la división de datos mediante REST.
  • PROJECT_NUMBER: el número de proyecto de tu proyecto generado de forma automática.

Método HTTP y URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Cuerpo JSON de la solicitud:

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "regression",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Para enviar tu solicitud, expande una de estas opciones:

Deberías recibir una respuesta JSON similar a la que se muestra a continuación:

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

Antes de probar este ejemplo, sigue las instrucciones de configuración para Java incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Java.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableRegression(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      ArrayList<Transformation> tranformations = new ArrayList<>();
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("TIMESTAMP_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("DATETIME_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
              .build());

      AutoMlTablesInputs trainingTaskInputs =
          AutoMlTablesInputs.newBuilder()
              .addAllTransformations(tranformations)
              .setTargetColumn(targetColumn)
              .setPredictionType("regression")
              .setTrainBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              // supported regression optimisation objectives: minimize-rmse,
              // minimize-mae, minimize-rmsle
              .setOptimizationObjective("minimize-rmse")
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Regression Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Antes de probar este ejemplo, sigue las instrucciones de configuración para Node.js incluidas en la guía de inicio rápido de Vertex AI sobre cómo usar bibliotecas cliente. Para obtener más información, consulta la documentación de referencia de la API de Vertex AI Node.js.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Si deseas obtener más información, consulta Configura la autenticación para un entorno de desarrollo local.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesRegression() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'BOOLEAN_2unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'TIMESTAMP_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'DATE_1unique_NULLABLE'}},
    {auto: {column_name: 'TIME_1unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'DATETIME_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'STRUCT_NULLABLE.STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE'}},
  ];

  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    transformations,
    targetColumn,
    predictionType: 'regression',
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-rmse',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular regression response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesRegression();

Python

Si deseas obtener información para instalar o actualizar el SDK de Vertex AI para Python, consulta Instala el SDK de Vertex AI para Python. Si deseas obtener más información, consulta la documentación de referencia de la API de Python.

def create_training_pipeline_tabular_regression_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_regression_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="regression"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_regression_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Controla la división de datos mediante REST

Puedes controlar cómo se dividen los datos de entrenamiento entre los conjuntos de entrenamiento, validación y prueba. Cuando uses la API de Vertex AI, usa el objeto Split para determinar la división de datos. El objeto Split se puede incluir en el objeto inputDataConfig como uno de varios tipos de objeto, cada uno de los cuales proporciona una forma diferente de dividir los datos de entrenamiento.

Los métodos que puedes usar para dividir tus datos dependen de tu tipo de datos:

  • FractionSplit:

    • TRAINING_FRACTION: La fracción de los datos de entrenamiento que se usarán para el conjunto de entrenamiento.
    • VALIDATION_FRACTION: La fracción de los datos de entrenamiento que se usarán para el conjunto de validación.
    • TEST_FRACTION: La fracción de los datos de entrenamiento que se usarán para el conjunto de prueba.

    Si se especifican cualquiera de las fracciones, se deben especificar todas. Las fracciones deben sumar hasta 1.0. Obtén más información.

    "fractionSplit": {
    "trainingFraction": TRAINING_FRACTION,
    "validationFraction": VALIDATION_FRACTION,
    "testFraction": TEST_FRACTION
    },
    

  • PredefinedSplit:

    • DATA_SPLIT_COLUMN: La columna que contiene los valores de división de datos (TRAIN, VALIDATION y TEST).

    Especifica manualmente la división de datos para cada fila mediante una columna dividida. Obtén más información.

    "predefinedSplit": {
      "key": DATA_SPLIT_COLUMN
    },
    
  • TimestampSplit:

    • TRAINING_FRACTION: El porcentaje de los datos de entrenamiento que se usará para el conjunto de entrenamiento. El valor predeterminado es 0.80.
    • VALIDATION_FRACTION: El porcentaje de los datos de entrenamiento que se usará para el conjunto de validación. El valor predeterminado es 0.10.
    • TEST_FRACTION: El porcentaje de los datos de entrenamiento que se usarán para el conjunto de prueba. El valor predeterminado es 0.10.
    • TIME_COLUMN: La columna que contiene las marcas de tiempo.

    Si se especifican cualquiera de las fracciones, se deben especificar todas. Las fracciones deben sumar 1.0. Obtén más información.

    "timestampSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION,
      "key": TIME_COLUMN
    }
    

Objetivos de optimización para los modelos de regresión o clasificación

Cuando entrenas un modelo, Vertex AI selecciona un objetivo de optimización predeterminado según el tipo de modelo y el tipo de datos que se usa para la columna objetivo.

Los modelos de clasificación son mejores para los siguientes casos:
Objetivo de optimización Valor de la API Usa este objetivo si quieres…
AUC ROC maximize-au-roc Maximizar área bajo la curva de característica operativa del receptor (ROC). Distingue las clases. Valor predeterminado para la clasificación binaria.
Pérdida logística minimize-log-loss Mantener las probabilidades de predicción lo más precisas posible. Solo es compatible con la clasificación de clases múltiples.
AUC PR maximize-au-prc Maximizar área debajo de la curva de precisión-recuperación. Optimiza los resultados para las predicciones de la clase menos común.
Precisión en recuperación maximize-precision-at-recall Optimizar la precisión en un valor de recuperación específico.
Recuperación en precisión maximize-recall-at-precision Optimizar la recuperación con un valor de precisión específico.
Los modelos de regresión son mejores para los siguientes casos:
Objetivo de optimización Valor de la API Usa este objetivo si quieres…
RMSE minimize-rmse Minimiza raíz cuadrada del error cuadrático medio (RMSE). Captura valores más extremos con exactitud. Valor predeterminado
MAE minimize-mae Minimizar el error absoluto promedio (MAE) Observa los valores extremos como valores atípicos con un impacto menor en el modelo.
RMSLE minimize-rmsle Minimizar el error de registro de la raíz cuadrada de la media (RMSLE) Penaliza errores de tamaño relativo en lugar de valor absoluto. Es útil cuando los valores previstos y reales pueden ser bastante grandes.

¿Qué sigue?