Entraîner un modèle AutoML (API Vertex AI)

Cette page explique comment entraîner un modèle AutoML à l'aide de l'API d'IA Vertex.

Pour en savoir plus sur l'utilisation de la console Google Cloud afin d'entraîner un modèle AutoML, consultez la page Entraîner un modèle AutoML à l'aide de Google Cloud Console.

Avant de commencer

Avant d'entraîner un modèle, vous devez préparer vos données d'entraînement et créer un ensemble de données.

Entraîner un modèle AutoML à l'aide de l'API

Lorsque vous entraînez un modèle à l'aide de l'API, vous créez un objet TrainingPipeline, en spécifiant l'ensemble de données contenant vos données d'entraînement.

Sélectionnez votre type de données ci-dessous :

Image

Sélectionnez l'onglet correspondant à votre objectif :

Classification

Sélectionnez l'onglet correspondant à votre langage ou à votre environnement :

API REST et ligne de commande

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région d'emplacement de l'ensemble de données et de création du modèle. Exemple : us-central1.
  • PROJECT : ID de votre projet
  • TRAININGPIPELINE_DISPLAYNAME : valeur obligatoire. Nom à afficher pour le trainingPipeline.
  • DATASET_ID : ID de l'ensemble de données à utiliser pour l'entraînement.
  • fractionSplit : facultatif. Une des nombreuses options de répartition possibles en cas d'utilisation de ML pour vos données. Pour fractionSplit, les valeurs doivent être égales à 1. Par exemple :
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME* : nom à afficher pour le modèle importé (créé) par le TrainingPipeline.
  • MODEL_DESCRIPTION* : description du modèle.
  • modelToUpload.labels* : tout ensemble de paires clé/valeur pour organiser vos modèles. Exemple :
    • "env" : "prod"
    • "tier" : "backend"
  • MODELTYPE : type de modèle hébergé dans le cloud à entraîner. Vous disposez des options suivantes :
    • CLOUD (par défaut)
  • NODE_HOUR_BUDGET : le coût réel de l'entraînement sera égal ou inférieur à cette valeur. Pour les modèles cloud, le budget doit être compris entre 8 000 et 800 000 milli-nœuds-heure (inclus). La valeur par défaut est de 192 000, ce qui correspond à une durée d'exécution d'une journée, en supposant que 8 nœuds sont utilisés.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "false",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

Pour envoyer votre requête, choisissez l'une des options suivantes :

curl

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

curl -X POST \
-H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

La réponse contient des informations sur les spécifications, ainsi que sur TRAININGPIPELINE_ID.

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME;
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Classification

Sélectionnez l'onglet correspondant à votre langage ou à votre environnement :

API REST et ligne de commande

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région d'emplacement de l'ensemble de données et de création du modèle. Exemple : us-central1.
  • PROJECT : ID de votre projet
  • TRAININGPIPELINE_DISPLAYNAME : valeur obligatoire. Nom à afficher pour le trainingPipeline.
  • DATASET_ID : ID de l'ensemble de données à utiliser pour l'entraînement.
  • fractionSplit : facultatif. Une des nombreuses options de répartition possibles en cas d'utilisation de ML pour vos données. Pour fractionSplit, les valeurs doivent être égales à 1. Par exemple :
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME* : nom à afficher pour le modèle importé (créé) par le TrainingPipeline.
  • MODEL_DESCRIPTION* : description du modèle.
  • modelToUpload.labels* : tout ensemble de paires clé/valeur pour organiser vos modèles. Exemple :
    • "env" : "prod"
    • "tier" : "backend"
  • MODELTYPE : type de modèle hébergé dans le cloud à entraîner. Vous disposez des options suivantes :
    • CLOUD (par défaut)
  • NODE_HOUR_BUDGET : le coût réel de l'entraînement sera égal ou inférieur à cette valeur. Pour les modèles cloud, le budget doit être compris entre 8 000 et 800 000 milli-nœuds-heure (inclus). La valeur par défaut est de 192 000, ce qui correspond à une durée d'exécution d'une journée, en supposant que 8 nœuds sont utilisés.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "true",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

Pour envoyer votre requête, choisissez l'une des options suivantes :

curl

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

curl -X POST \
-H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

La réponse contient des informations sur les spécifications, ainsi que sur TRAININGPIPELINE_ID.

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME;
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Détection d'objets

Sélectionnez l'onglet correspondant à votre langage ou à votre environnement :

API REST et ligne de commande

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région d'emplacement de l'ensemble de données et de création du modèle. Exemple : us-central1.
  • PROJECT : ID de votre projet
  • TRAININGPIPELINE_DISPLAYNAME : valeur obligatoire. Nom à afficher pour le trainingPipeline.
  • DATASET_ID : ID de l'ensemble de données à utiliser pour l'entraînement.
  • fractionSplit : facultatif. Une des nombreuses options de répartition possibles en cas d'utilisation de ML pour vos données. Pour fractionSplit, les valeurs doivent être égales à 1. Par exemple :
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME* : nom à afficher pour le modèle importé (créé) par le TrainingPipeline.
  • MODEL_DESCRIPTION* : description du modèle.
  • modelToUpload.labels* : tout ensemble de paires clé/valeur pour organiser vos modèles. Exemple :
    • "env" : "prod"
    • "tier" : "backend"
  • MODELTYPE : type de modèle hébergé dans le cloud à entraîner. Vous disposez des options suivantes :
    • CLOUD-HIGH-ACCURACY-1 : modèle optimal pour une utilisation dans Google Cloud et impossible à exporter. Ce modèle doit générer une latence plus élevée, mais la qualité des prédictions est également censée être supérieure à celle des autres modèles cloud.
    • CLOUD-LOW-LATENCY-1 : modèle optimal pour une utilisation dans Google Cloud et impossible à exporter. Ce modèle doit générer une faible latence, mais la qualité des prédictions peut s'avérer inférieure à celle des autres modèles cloud.
  • NODE_HOUR_BUDGET : le coût réel de l'entraînement sera égal ou inférieur à cette valeur. Pour les modèles cloud, le budget doit être compris entre 20 000 et 900 000 milli-nœuds-heure (inclus). La valeur par défaut est de 216 000, ce qui correspond à une durée d'exécution d'une journée, en supposant que 9 nœuds sont utilisés.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml",
  "trainingTaskInputs": {
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

Pour envoyer votre requête, choisissez l'une des options suivantes :

curl

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

curl -X POST \
-H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

La réponse contient des informations sur les spécifications, ainsi que sur TRAININGPIPELINE_ID.

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageObjectDetectionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageObjectDetectionSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageObjectDetectionSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_object_detection_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageObjectDetectionInputs autoMlImageObjectDetectionInputs =
          AutoMlImageObjectDetectionInputs.newBuilder()
              .setModelType(ModelType.CLOUD_HIGH_ACCURACY_1)
              .setBudgetMilliNodeHours(20000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageObjectDetectionInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Object Detection Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';

const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageObjectDetectionInputs.ModelType;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageObjectDetection() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputsObj =
    new definition.AutoMlImageObjectDetectionInputs({
      disableEarlyStopping: false,
      modelType: ModelType.CLOUD_HIGH_ACCURACY_1,
      budgetMilliNodeHours: 20000,
    });

  const trainingTaskInputs = trainingTaskInputsObj.toValue();
  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline image object detection response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineImageObjectDetection();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import trainingjob

def create_training_pipeline_image_object_detection_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    training_task_inputs = trainingjob.definition.AutoMlImageObjectDetectionInputs(
        model_type="CLOUD_HIGH_ACCURACY_1",
        budget_milli_node_hours=20000,
        disable_early_stopping=False,
    ).to_value()

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {"dataset_id": dataset_id},
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

Tabulaire

Sélectionnez un objectif de type de données tabulaires.

Classification

Sélectionnez un onglet pour votre langage ou environnement :

API REST et ligne de commande

Vous utilisez la commande trainingPipelines.create pour entraîner un modèle.

Entraîner le modèle

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : votre région.
  • PROJECT : ID de votre projet
  • TRAININGPIPELINE_DISPLAY_NAME : nom à afficher du pipeline d'entraînement créé pour cette opération.
  • TARGET_COLUMN : colonne (valeur) que le modèle doit prédire.
  • WEIGHT_COLUMN (facultatif) : colonne de pondération. En savoir plus
  • TRAINING_BUDGET : durée maximale pendant laquelle le modèle doit être entraîné, en milli-nœuds-heure (1 000 milli-nœuds-heure correspondent à un nœud-heure).
  • OPTIMIZATION_OBJECTIVE : obligatoire uniquement si vous ne souhaitez pas atteindre l'objectif d'optimisation par défaut pour votre type de prédiction. En savoir plus
  • TRANSFORMATION_TYPE : le type de transformation est fourni pour chaque colonne utilisée pour entraîner le modèle. En savoir plus
  • COLUMN_NAME : nom de la colonne avec le type de transformation spécifié. Chaque colonne utilisée pour entraîner le modèle doit être spécifiée.
  • MODEL_DISPLAY_NAME : nom à afficher du modèle nouvellement entraîné.
  • DATASET_ID : ID de l'ensemble de données d'entraînement.
  • Vous pouvez fournir un objet Split pour contrôler votre répartition des données. Pour en savoir plus sur le contrôle de la répartition des données, consultez la section Contrôler la répartition des données à l'aide de REST.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "classification",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableClassification(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      Transformation transformation1 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build())
              .build();
      Transformation transformation2 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build())
              .build();
      Transformation transformation3 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build())
              .build();
      Transformation transformation4 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build())
              .build();

      ArrayList<Transformation> transformationArrayList = new ArrayList<>();
      transformationArrayList.add(transformation1);
      transformationArrayList.add(transformation2);
      transformationArrayList.add(transformation3);
      transformationArrayList.add(transformation4);

      AutoMlTablesInputs autoMlTablesInputs =
          AutoMlTablesInputs.newBuilder()
              .setTargetColumn(targetColumn)
              .setPredictionType("classification")
              .addAllTransformations(transformationArrayList)
              .setTrainBudgetMilliNodeHours(8000)
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'sepal_width'}},
    {auto: {column_name: 'sepal_length'}},
    {auto: {column_name: 'petal_length'}},
    {auto: {column_name: 'petal_width'}},
  ];
  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    targetColumn: targetColumn,
    predictionType: 'classification',
    transformations: transformations,
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-log-loss',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline tabular classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesClassification();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_tabular_classification_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = None,
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_classification_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="classification"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_classification_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Prévision

Sélectionnez un onglet pour votre langage ou environnement :

API REST et ligne de commande

Vous utilisez la commande trainingPipelines.create pour entraîner un modèle.

Entraîner le modèle

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : votre région.
  • PROJECT : ID de votre projet
  • TRAINING_PIPELINE_DISPLAY_NAME : nom à afficher du pipeline d'entraînement créé pour cette opération.
  • TRAINING_TASK_DEFINITION : méthode d'entraînement du modèle
    • AutoML : un bon choix pour un large éventail de cas d'utilisation.
      gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml
    • Seq2Seq+ : un bon choix pour l'expérimentation. L'algorithme est susceptible de converger plus rapidement qu'AutoML, car son architecture est plus simple et il utilise un espace de recherche plus petit. Nos tests montrent que Seq2Seq+ offre de bons résultats avec un petit budget-temps et des ensembles de données dont la taille est inférieure à 1 Go.
      gs://google-cloud-aiplatform/schema/trainingjob/definition/seq2seq_plus_time_series_forecasting_1.0.0.yaml
  • TARGET_COLUMN : colonne (valeur) que le modèle doit prédire.
  • TIME_COLUMN : colonne Heure. En savoir plus
  • TIME_SERIES_IDENTIFIER_COLUMN : colonne de l'identifiant de série temporelle. En savoir plus
  • WEIGHT_COLUMN (facultatif) : colonne de pondération. En savoir plus
  • TRAINING_BUDGET : durée maximale pendant laquelle le modèle doit être entraîné, en milli-nœuds-heure (1 000 milli-nœuds-heure correspondent à un nœud-heure).
  • GRANULARITY_UNIT : unité à utiliser pour la précision de vos données d'entraînement, de l'horizon de prévision et de la fenêtre de contexte. Il peut s'agir de minute, hour, day, week, month ou year. Sélectionnez day si vous souhaitez utiliser la modélisation des effets des jours fériés. En savoir plus
  • GRANULARITY_QUANTITY : nombre d'unités de précision qui composent l'intervalle entre les observations dans vos données d'entraînement. Doit être égal à un pour toutes les unités sauf les minutes, pouvant correspondre à 1, 5, 10, 15 ou 30. En savoir plus
  • GROUP_COLUMNS : noms de colonne de votre table d'entrée d'entraînement qui identifient le regroupement au niveau de la hiérarchie. La ou les colonnes doivent être "time_series_attribute_columns". Apprenez-en plus.
  • GROUP_TOTAL_WEIGHT : pondération de la perte agrégée du groupe par rapport à la perte individuelle. Désactivée si la valeur est définie sur "0.0" ou n'est pas définie. Si la colonne de groupe n'est pas définie, toutes les séries temporelles seront traitées dans le même groupe et agrégées sur toutes les séries temporelles. En savoir plus
  • TEMPORAL_TOTAL_WEIGHT : pondération de la perte agrégée dans le temps par rapport à la perte individuelle. Désactivée si la valeur est définie sur "0.0" ou n'est pas définie. En savoir plus
  • GROUP_TEMPORAL_TOTAL_WEIGHT : pondération de la perte totale (groupe x temps) par rapport à la perte individuelle. Désactivée si la valeur est définie sur "0.0" ou n'est pas définie. Si la colonne de groupe n'est pas définie, toutes les séries temporelles seront traitées dans le même groupe et agrégées sur toutes les séries temporelles. En savoir plus
  • HOLIDAY_REGIONS : (facultatif) une ou plusieurs régions géographiques en fonction desquelles l'effet des jours fériés est appliqué dans la modélisation. Pendant l'entraînement, Vertex AI crée des caractéristiques catégorielles de jours fériés dans le modèle en fonction de la date de la colonne Heure et des régions géographiques spécifiées. Pour activer cette fonctionnalité, définissez GRANULARITY_UNIT sur day, puis spécifiez une ou plusieurs régions dans le champ HOLIDAY_REGIONS. Par défaut, la modélisation des effets des jours fériés est désactivée.

    Les valeurs acceptées sont les suivantes :

    • GLOBAL : détecte les jours fériés dans toutes les régions du monde.
    • NA : détecte les jours fériés en Amérique du Nord
    • JAPAC : détecte les jours fériés au Japon et en Asie-Pacifique
    • EMEA : détecte les jours fériés en Europe, au Moyen-Orient et en Afrique
    • LAC : détecte les jours fériés en Amérique latine et dans les Caraïbes
    • Codes pays ISO 3166-1 : détecte les jours fériés des pays individuels.
  • FORECAST_HORIZON : taille de l'horizon de prévision, spécifiée en unités de précision. L'horizon des prévisions correspond à la période pour laquelle le modèle doit prévoir les résultats. En savoir plus
  • CONTEXT_WINDOW : nombre d'unités de précision que le modèle doit examiner pour les inclure au moment de l'entraînement. En savoir plus
  • OPTIMIZATION_OBJECTIVE : obligatoire uniquement si vous ne souhaitez pas atteindre l'objectif d'optimisation par défaut pour votre type de prédiction. En savoir plus
  • TIME_SERIES_ATTRIBUTE_COL : nom ou noms des colonnes qui sont des attributs de série temporelle. En savoir plus
  • AVAILABLE_AT_FORECAST_COL : nom ou noms des colonnes covariées dont la valeur est connue au moment de la prévision. En savoir plus
  • UNAVAILABLE_AT_FORECAST_COL : nom ou noms des colonnes covariées dont la valeur est inconnue au moment de la prévision. En savoir plus
  • TRANSFORMATION_TYPE : le type de transformation est fourni pour chaque colonne utilisée pour entraîner le modèle. En savoir plus
  • COLUMN_NAME : nom de la colonne avec le type de transformation spécifié. Chaque colonne utilisée pour entraîner le modèle doit être spécifiée.
  • MODEL_DISPLAY_NAME : nom à afficher du modèle nouvellement entraîné.
  • DATASET_ID : ID de l'ensemble de données d'entraînement.
  • Vous pouvez fournir un objet Split pour contrôler votre répartition des données. Pour en savoir plus sur le contrôle de la répartition des données, consultez la section Contrôler la répartition des données à l'aide de REST.
  • Vous pouvez fournir un objet windowConfig pour configurer une période de prévision. Pour en savoir plus, consultez la section Configurer l'intervalle de prévision à l'aide de REST.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
    "displayName": "TRAINING_PIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "TRAINING_TASK_DEFINITION",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "timeColumn": "TIME_COLUMN",
        "timeSeriesIdentifierColumn": "TIME_SERIES_IDENTIFIER_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "dataGranularity": {"unit": "GRANULARITY_UNIT", "quantity": GRANULARITY_QUANTITY},
        "hierarchyConfig": {"groupColumns": GROUP_COLUMNS, "groupTotalWeight": GROUP_TOTAL_WEIGHT, "temporalTotalWeight": TEMPORAL_TOTAL_WEIGHT, "groupTemporalTotalWeight": GROUP_TEMPORAL_TOTAL_WEIGHT}
        "holidayRegions" : ["HOLIDAY_REGIONS_1", "HOLIDAY_REGIONS_2", ...]
        "forecast_horizon": FORECAST_HORIZON,
        "context_window": CONTEXT_WINDOW,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "time_series_attribute_columns": ["TIME_SERIES_ATTRIBUTE_COL_1", "TIME_SERIES_ATTRIBUTE_COL_2", ...]
        "available_at_forecast_columns": ["AVAILABLE_AT_FORECAST_COL_1", "AVAILABLE_AT_FORECAST_COL_2", ...]
        "unavailable_at_forecast_columns": ["UNAVAILABLE_AT_FORECAST_COL_1", "UNAVAILABLE_AT_FORECAST_COL_2", ...]
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/LOCATION/trainingPipelines/TRAINING_PIPELINE_ID",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

def create_training_pipeline_tabular_forecasting_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    target_column: str,
    time_series_identifier_column: str,
    time_column: str,
    time_series_attribute_columns: str,
    unavailable_at_forecast: str,
    available_at_forecast: str,
    forecast_horizon: int,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    # set the columns used for training and their data types
    transformations = [
        {"auto": {"column_name": "date"}},
        {"auto": {"column_name": "state_name"}},
        {"auto": {"column_name": "county_fips_code"}},
        {"auto": {"column_name": "confirmed_cases"}},
        {"auto": {"column_name": "deaths"}},
    ]

    data_granularity = {"unit": "day", "quantity": 1}

    # the inputs should be formatted according to the training_task_definition yaml file
    training_task_inputs_dict = {
        # required inputs
        "targetColumn": target_column,
        "timeSeriesIdentifierColumn": time_series_identifier_column,
        "timeColumn": time_column,
        "transformations": transformations,
        "dataGranularity": data_granularity,
        "optimizationObjective": "minimize-rmse",
        "trainBudgetMilliNodeHours": 8000,
        "timeSeriesAttributeColumns": time_series_attribute_columns,
        "unavailableAtForecast": unavailable_at_forecast,
        "availableAtForecast": available_at_forecast,
        "forecastHorizon": forecast_horizon,
    }

    training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_forecasting_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {
            "dataset_id": dataset_id,
            "fraction_split": {
                "training_fraction": 0.8,
                "validation_fraction": 0.1,
                "test_fraction": 0.1,
            },
        },
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

Régression

Sélectionnez un onglet pour votre langage ou environnement :

API REST et ligne de commande

Vous utilisez la commande trainingPipelines.create pour entraîner un modèle.

Entraîner le modèle

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : votre région.
  • PROJECT : ID de votre projet
  • TRAININGPIPELINE_DISPLAY_NAME : nom à afficher du pipeline d'entraînement créé pour cette opération.
  • TARGET_COLUMN : colonne (valeur) que le modèle doit prédire.
  • WEIGHT_COLUMN (facultatif) : colonne de pondération. En savoir plus
  • TRAINING_BUDGET : durée maximale pendant laquelle le modèle doit être entraîné, en milli-nœuds-heure (1 000 milli-nœuds-heure correspondent à un nœud-heure).
  • OPTIMIZATION_OBJECTIVE : obligatoire uniquement si vous ne souhaitez pas atteindre l'objectif d'optimisation par défaut pour votre type de prédiction. En savoir plus
  • TRANSFORMATION_TYPE : le type de transformation est fourni pour chaque colonne utilisée pour entraîner le modèle. En savoir plus
  • COLUMN_NAME : nom de la colonne avec le type de transformation spécifié. Chaque colonne utilisée pour entraîner le modèle doit être spécifiée.
  • MODEL_DISPLAY_NAME : nom à afficher du modèle nouvellement entraîné.
  • DATASET_ID : ID de l'ensemble de données d'entraînement.
  • Vous pouvez fournir un objet Split pour contrôler votre répartition des données. Pour en savoir plus sur le contrôle de la répartition des données, consultez la section Contrôler la répartition des données à l'aide de REST.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "regression",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableRegression(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      ArrayList<Transformation> tranformations = new ArrayList<>();
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("TIMESTAMP_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("DATETIME_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
              .build());

      AutoMlTablesInputs trainingTaskInputs =
          AutoMlTablesInputs.newBuilder()
              .addAllTransformations(tranformations)
              .setTargetColumn(targetColumn)
              .setPredictionType("regression")
              .setTrainBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              // supported regression optimisation objectives: minimize-rmse,
              // minimize-mae, minimize-rmsle
              .setOptimizationObjective("minimize-rmse")
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Regression Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesRegression() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'BOOLEAN_2unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'TIMESTAMP_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'DATE_1unique_NULLABLE'}},
    {auto: {column_name: 'TIME_1unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'DATETIME_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'STRUCT_NULLABLE.STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE'}},
  ];

  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    transformations,
    targetColumn,
    predictionType: 'regression',
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-rmse',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline tabular regression response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesRegression();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_tabular_regression_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_regression_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="regression"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_regression_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Texte

Sélectionnez un objectif de type de données textuelles.

Classification

Sélectionnez un onglet pour votre langage ou environnement :

API REST et ligne de commande

Vous utilisez la commande trainingPipelines.create pour entraîner un modèle.

Créez un objet TrainingPipeline pour entraîner un modèle.

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région dans laquelle le modèle sera créé, par exemple us-central1.
  • PROJECT : ID de votre projet
  • MODEL_DISPLAY_NAME : nom du modèle tel qu'il apparaît dans l'interface utilisateur.
  • MULTI-LABEL : valeur booléenne indiquant si Vertex AI entraîne un modèle multi-étiquette. la valeur par défaut est false (modèle à étiquette unique).
  • DATASET_ID : ID de l'ensemble de données.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "MODEL_DISPLAY_NAME",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": MULTI-LABEL
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/PIPELINE_ID",
  "displayName": "MODEL_DISPLAY_NAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": MULTI-LABEL
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlTextClassificationInputs;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineTextClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";

    createTrainingPipelineTextClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineTextClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_text_classification_1.0.0.yaml";

      LocationName locationName = LocationName.of(project, location);

      AutoMlTextClassificationInputs trainingTaskInputs =
          AutoMlTextClassificationInputs.newBuilder().setMultiLabel(false).build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Text Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());

      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("\t\tPredict Schemata");
      System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\t\tSupported Export Format");
        System.out.format("\t\t\tId: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("\t\tContainer Spec");
      System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
      System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
      System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("\t\t\tEnv");
        System.out.format("\t\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("\t\t\tPort");
        System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\t\tDeployed Model");
        System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTextClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputObj = new definition.AutoMlTextClassificationInputs({
    multiLabel: false,
  });
  const trainingTaskInputs = trainingTaskInputObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_classification_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline text classification response :');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTextClassification();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_text_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLTextTrainingJob(
        display_name=display_name,
        prediction_type="classification",
        multi_label=multi_label,
    )

    text_dataset = aiplatform.TextDataset(dataset_id)

    model = job.run(
        dataset=text_dataset,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Extraction d'entités

Sélectionnez un onglet pour votre langage ou environnement :

API REST et ligne de commande

Vous utilisez la commande trainingPipelines.create pour entraîner un modèle.

Créez un objet TrainingPipeline pour entraîner un modèle.

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région dans laquelle le modèle sera créé, par exemple us-central1.
  • PROJECT : ID de votre projet
  • MODEL_DISPLAY_NAME : nom du modèle tel qu'il apparaît dans l'interface utilisateur.
  • DATASET_ID : ID de l'ensemble de données.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "MODEL_DISPLAY_NAME",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/PIPELINE_ID",
  "displayName": "MODEL_DISPLAY_NAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineTextEntityExtractionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";

    createTrainingPipelineTextEntityExtractionSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineTextEntityExtractionSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_text_extraction_1.0.0.yaml";

      LocationName locationName = LocationName.of(project, location);

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.EMPTY_VALUE)
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Text Entity Extraction Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());

      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("\t\tPredict Schemata");
      System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\t\tSupported Export Format");
        System.out.format("\t\t\tId: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("\t\tContainer Spec");
      System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
      System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
      System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("\t\t\tEnv");
        System.out.format("\t\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("\t\t\tPort");
        System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\t\tDeployed Model");
        System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTextEntityExtraction() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputObj = new definition.AutoMlTextExtractionInputs({});
  const trainingTaskInputs = trainingTaskInputObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_extraction_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline text entity extraction response :');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTextEntityExtraction();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_text_entity_extraction_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLTextTrainingJob(
        display_name=display_name, prediction_type="extraction"
    )

    text_dataset = aiplatform.TextDataset(dataset_id)

    model = job.run(
        dataset=text_dataset,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Analyse des sentiments

Sélectionnez un onglet pour votre langage ou environnement :

API REST et ligne de commande

Vous utilisez la commande trainingPipelines.create pour entraîner un modèle.

Créez un objet TrainingPipeline pour entraîner un modèle.

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région dans laquelle le modèle sera créé, par exemple us-central1.
  • PROJECT : ID de votre projet
  • MODEL_DISPLAY_NAME : nom du modèle tel qu'il apparaît dans l'interface utilisateur.
  • SENTIMENT_MAX : score de sentiment maximal dans votre ensemble de données d'entraînement.
  • DATASET_ID : ID de l'ensemble de données.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "MODEL_DISPLAY_NAME",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_sentiment_1.0.0.yaml",
  "trainingTaskInputs": {
    "sentimentMax": SENTIMENT_MAX
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/PIPELINE_ID",
  "displayName": "MODEL_DISPLAY_NAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID"
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_sentiment_1.0.0.yaml",
  "trainingTaskInputs": {
    "sentimentMax": SENTIMENT_MAX
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTextSentimentInputs;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineTextSentimentAnalysisSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";

    createTrainingPipelineTextSentimentAnalysisSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineTextSentimentAnalysisSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_text_sentiment_1.0.0.yaml";

      LocationName locationName = LocationName.of(project, location);

      AutoMlTextSentimentInputs trainingTaskInputs =
          AutoMlTextSentimentInputs.newBuilder()
              // Sentiment max must be between 1 and 10 inclusive.
              // Higher value means positive sentiment.
              .setSentimentMax(4)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Text Sentiment Analysis Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "\tTraining Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("\t\tAnnotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());

      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMetadata: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLabels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("\t\tPredict Schemata");
      System.out.format("\t\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\t\tSupported Export Format");
        System.out.format("\t\t\tId: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("\t\tContainer Spec");
      System.out.format("\t\t\tImage Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("\t\t\tCommand: %s\n", modelContainerSpec.getCommandList());
      System.out.format("\t\t\tArgs: %s\n", modelContainerSpec.getArgsList());
      System.out.format("\t\t\tPredict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("\t\t\tHealth Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("\t\t\tEnv");
        System.out.format("\t\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("\t\t\tPort");
        System.out.format("\t\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\t\tDeployed Model");
        System.out.format("\t\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTextSentimentAnalysis() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputObj = new definition.AutoMlTextSentimentInputs({
    sentimentMax: 4,
  });
  const trainingTaskInputs = trainingTaskInputObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_text_sentiment_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline text sentiment analysis response :');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTextSentimentAnalysis();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

def create_training_pipeline_text_sentiment_analysis_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    sentiment_max: int = 10,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLTextTrainingJob(
        display_name=display_name,
        prediction_type="sentiment",
        sentiment_max=sentiment_max,
    )

    text_dataset = aiplatform.TextDataset(dataset_id)

    model = job.run(
        dataset=text_dataset,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Vidéo

Sélectionnez l'onglet correspondant à votre objectif :

Reconnaissance des actions

Sélectionnez l'onglet correspondant à votre langage ou à votre environnement :

API REST et ligne de commande

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • PROJECT : ID de votre projet
  • LOCATION : région d'emplacement de l'ensemble de données et de création du modèle. Exemple :us-central1
  • TRAINING_PIPELINE_DISPLAY_NAME : valeur obligatoire. Nom à afficher pour le TrainingPipeline.
  • DATASET_ID : ID de l'ensemble de données d'entraînement.
  • TRAINING_FRACTION, TEST_FRACTION : l'objet fractionSplit est facultatif, il sert à contrôler la répartition des données. Pour en savoir plus sur le contrôle de la répartition des données, consultez la page À propos de la répartition des données pour les modèles AutoML. Exemple :
    • {"trainingFraction": "0.8","validationFraction": "0","testFraction": "0.2"}
  • MODEL_DISPLAY_NAME : nom à afficher du modèle entraîné.
  • MODEL_DESCRIPTION : description du modèle.
  • MODEL_LABELS : tout ensemble de paires clé/valeur pour organiser vos modèles. Exemple :
    • "env" : "prod"
    • "tier" : "backend"
  • EDGE_MODEL_TYPE :
    • MOBILE_VERSATILE_1 : usage général
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/beta1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
  "displayName": "TRAINING_PIPELINE_DISPLAY_NAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "TRAINING_FRACTION",
      "validationFraction": "0",
      "testFraction": "TEST_FRACTION"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAY_NAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
  "trainingTaskInputs": {
    "modelType": ["EDGE_MODEL_TYPE"],
  }
}

Pour envoyer votre requête, choisissez l'une des options suivantes :

curl

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

curl -X POST \
-H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/beta1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Enregistrez le corps de la requête dans un fichier nommé request.json, puis exécutez la commande suivante :

$cred = gcloud auth application-default print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/beta1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

La réponse contient des informations sur les spécifications, ainsi que sur TRAININGPIPELINE_ID.

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoActionRecognitionInputs.ModelType;
import java.io.IOException;

public class CreateTrainingPipelineVideoActionRecognitionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "PROJECT";
    String displayName = "DISPLAY_NAME";
    String datasetId = "DATASET_ID";
    String modelDisplayName = "MODEL_DISPLAY_NAME";
    createTrainingPipelineVideoActionRecognitionSample(
        project, displayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineVideoActionRecognitionSample(
      String project, String displayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings settings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
    String location = "us-central1";

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
      AutoMlVideoActionRecognitionInputs trainingTaskInputs =
          AutoMlVideoActionRecognitionInputs.newBuilder().setModelType(ModelType.CLOUD).build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(displayName)
              .setTrainingTaskDefinition(
                  "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
                      + "automl_video_action_recognition_1.0.0.yaml")
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();
      LocationName parent = LocationName.of(project, location);
      TrainingPipeline response = client.createTrainingPipeline(parent, trainingPipeline);
      System.out.format("response: %s\n", response);
      System.out.format("Name: %s\n", response.getName());
    }
  }
}

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import trainingjob

def create_training_pipeline_video_action_recognition_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    model_type: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    training_task_inputs = trainingjob.definition.AutoMlVideoActionRecognitionInputs(
        # modelType can be either 'CLOUD' or 'MOBILE_VERSATILE_1'
        model_type=model_type,
    ).to_value()

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_action_recognition_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {"dataset_id": dataset_id},
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

Classification

Sélectionnez l'onglet correspondant à votre langage ou à votre environnement :

API REST et ligne de commande

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région d'emplacement de l'ensemble de données et de stockage. Exemple : us-central1.
  • PROJECT : ID de votre projet
  • MODEL_DISPLAY_NAME : nom à afficher du modèle nouvellement entraîné.
  • DATASET_ID : ID de l'ensemble de données d'entraînement.
  • L'objet filterSplit est facultatif, il sert à contrôler la répartition des données. Pour en savoir plus sur le contrôle de la répartition des données, consultez la section Contrôler la répartition des données à l'aide de REST.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
    "displayName": "MODE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_classification_1.0.0.yaml",
    "trainingTaskInputs": {},
    "modelToUpload": {"displayName": "MODE_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
      "filterSplit": {
        "trainingFilter": "labels.ml_use = training",
        "validationFilter": "labels.ml_use = -",
        "testFilter": "labels.ml_use = test"
      }
    }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/2307109646608891904",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_classification_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineVideoClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String videoClassificationDisplayName =
        "YOUR_TRAINING_PIPELINE_VIDEO_CLASSIFICATION_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    createTrainingPipelineVideoClassification(
        videoClassificationDisplayName, datasetId, modelDisplayName, project);
  }

  static void createTrainingPipelineVideoClassification(
      String videoClassificationDisplayName,
      String datasetId,
      String modelDisplayName,
      String project)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_video_classification_1.0.0.yaml";

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(videoClassificationDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.EMPTY_VALUE)
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Video Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());
      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());
      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineVideoClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;
  // Values should match the input expected by your model.
  const trainingTaskInputObj = new definition.AutoMlVideoClassificationInputs(
    {}
  );
  const trainingTaskInputs = trainingTaskInputObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_classification_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline video classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineVideoClassification();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import trainingjob

def create_training_pipeline_video_classification_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    training_task_inputs = (
        trainingjob.definition.AutoMlVideoClassificationInputs().to_value()
    )

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_classification_1.0.0.yaml",
        # Training task inputs are empty for video classification
        "training_task_inputs": training_task_inputs,
        "input_data_config": {"dataset_id": dataset_id},
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

Suivi des objets

Sélectionnez l'onglet correspondant à votre langage ou à votre environnement :

API REST et ligne de commande

Avant d'utiliser les données de requête ci-dessous, effectuez les remplacements suivants :

  • LOCATION : région d'emplacement de l'ensemble de données et de stockage. Exemple : us-central1.
  • PROJECT : ID de votre projet
  • MODEL_DISPLAY_NAME : nom à afficher du modèle nouvellement entraîné.
  • DATASET_ID : ID de l'ensemble de données d'entraînement.
  • L'objet filterSplit est facultatif, il sert à contrôler la répartition des données. Pour en savoir plus sur le contrôle de la répartition des données, consultez la section Contrôler la répartition des données à l'aide de REST.
  • PROJECT_NUMBER : numéro de votre projet

Méthode HTTP et URL :

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Corps JSON de la requête :

{
    "displayName": "MODE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
    "trainingTaskInputs": {},
    "modelToUpload": {"displayName": "MODE_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
      "filterSplit": {
        "trainingFilter": "labels.ml_use = training",
        "validationFilter": "labels.ml_use = -",
        "testFilter": "labels.ml_use = test"
      }
    }
}

Pour envoyer votre requête, développez l'une des options suivantes :

Vous devriez recevoir une réponse JSON de ce type :

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/2307109646608891904",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}

Java

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Java.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineVideoObjectTrackingSample {

  public static void main(String[] args) throws IOException {
    String trainingPipelineVideoObjectTracking =
        "YOUR_TRAINING_PIPELINE_VIDEO_OBJECT_TRACKING_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    createTrainingPipelineVideoObjectTracking(
        trainingPipelineVideoObjectTracking, datasetId, modelDisplayName, project);
  }

  static void createTrainingPipelineVideoObjectTracking(
      String trainingPipelineVideoObjectTracking,
      String datasetId,
      String modelDisplayName,
      String project)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_video_object_tracking_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlVideoObjectTrackingInputs trainingTaskInputs =
          AutoMlVideoObjectTrackingInputs.newBuilder().setModelType(ModelType.CLOUD).build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineVideoObjectTracking)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline createTrainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Video Object Tracking Response");
      System.out.format("Name: %s\n", createTrainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", createTrainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n",
          createTrainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n",
          createTrainingPipelineResponse.getTrainingTaskInputs().toString());
      System.out.format(
          "Training Task Metadata: %s\n",
          createTrainingPipelineResponse.getTrainingTaskMetadata().toString());

      System.out.format("State: %s\n", createTrainingPipelineResponse.getState().toString());
      System.out.format(
          "Create Time: %s\n", createTrainingPipelineResponse.getCreateTime().toString());
      System.out.format("StartTime %s\n", createTrainingPipelineResponse.getStartTime().toString());
      System.out.format("End Time: %s\n", createTrainingPipelineResponse.getEndTime().toString());
      System.out.format(
          "Update Time: %s\n", createTrainingPipelineResponse.getUpdateTime().toString());
      System.out.format("Labels: %s\n", createTrainingPipelineResponse.getLabelsMap().toString());

      InputDataConfig inputDataConfigResponse = createTrainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data config");
      System.out.format("Dataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit();
      System.out.println("Fraction split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = createTrainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());
      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());

      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %s\n", modelResponse.getLabelsMap());

      Status status = createTrainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Node.js.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlVideoObjectTrackingInputs.ModelType;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineVideoObjectTracking() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputsObj =
    new definition.AutoMlVideoObjectTrackingInputs({
      modelType: ModelType.CLOUD,
    });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] = await pipelineServiceClient.createTrainingPipeline(
    request
  );

  console.log('Create training pipeline video object tracking response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineVideoObjectTracking();

Python

Pour savoir comment installer et utiliser la bibliothèque cliente pour Vertex AI, consultez Bibliothèques clientes Vertex AI. Pour en savoir plus, consultez la documentation de référence de l'API Vertex AI Python.

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import trainingjob

def create_training_pipeline_video_object_tracking_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    training_task_inputs = trainingjob.definition.AutoMlVideoObjectTrackingInputs(
        model_type="CLOUD",
    ).to_value()

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {"dataset_id": dataset_id},
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

Contrôler la répartition des données à l'aide de REST

Vous pouvez contrôler la manière dont vos données d'entraînement sont réparties entre les ensembles d'entraînement, de validation et de test. Lorsque vous utilisez l'API Vertex AI, déterminez la répartition des données à l'aide de l'objet Split. L'objet Split peut être inclus dans l'objet InputConfig sous la forme de plusieurs types d'objets, chacun offrant une manière différente de répartir les données d'entraînement. Vous ne pouvez sélectionner qu'une seule méthode.

Les méthodes que vous pouvez utiliser pour répartir vos données dépendent du type de données :

Image, texte, vidéo

  • FractionSplit :
    • TRAINING_FRACTION : fraction des données d'entraînement à utiliser pour l'ensemble d'entraînement.
    • VALIDATION_FRACTION : fraction des données d'entraînement à utiliser pour l'ensemble de validation. Non utilisé pour les données vidéo.
    • TEST_FRACTION : fraction des données d'entraînement à utiliser pour l'ensemble de test.

    Si l'une des fractions est spécifiée, elles doivent toutes être spécifiées. La somme des fractions doit être égale à 1,0. Les valeurs par défaut des fractions diffèrent selon le type de données. En savoir plus

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    
  • FilterSplit :
    • TRAINING_FILTER : les éléments de données correspondant à ce filtre sont utilisés pour l'ensemble d'entraînement.
    • VALIDATION_FILTER : les éléments de données correspondant à ce filtre sont utilisés pour l'ensemble de validation. La valeur doit être "-" pour les données vidéo.
    • TEST_FILTER : les éléments de données correspondant à ce filtre sont utilisés pour l'ensemble de test.

    Ces filtres peuvent être utilisés avec l'étiquette ml_use ou avec les étiquettes que vous appliquez à vos données. Découvrez comment filtrer vos données à l'aide de l'étiquette ml-use et d'autres étiquettes.

    L'exemple suivant montre comment utiliser l'objet filterSplit avec l'étiquette ml_use, avec l'ensemble de validation inclus :

    "filterSplit": {
      "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
      "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
      "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }
    

Tabulaire

  • FractionSplit :
    • TRAINING_FRACTION : fraction des données d'entraînement à utiliser pour l'ensemble d'entraînement.
    • VALIDATION_FRACTION : fraction des données d'entraînement à utiliser pour l'ensemble de validation. Non utilisé pour les données vidéo.
    • TEST_FRACTION : fraction des données d'entraînement à utiliser pour l'ensemble de test.

    Si l'une des fractions est spécifiée, elles doivent toutes être spécifiées. La somme des fractions doit être égale à 1,0. Les valeurs par défaut des fractions diffèrent selon le type de données. En savoir plus

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    

    L'objet fractionSplit n'est pas compatible avec les modèles de prévision.

  • PredefinedSplit :
    • DATA_SPLIT_COLUMN : colonne contenant les valeurs de répartition des données (TRAIN, VALIDATION, TEST).

    Spécifiez manuellement la répartition des données pour chaque ligne à l'aide d'une colonne fractionnée. En savoir plus

    "predefinedSplit": {
      "key": DATA_SPLIT_COLUMN
    },
    
  • TimestampSplit :
    • TRAINING_FRACTION : pourcentage des données d'entraînement à utiliser pour l'ensemble d'entraînement. La valeur par défaut est 0,80.
    • VALIDATION_FRACTION : pourcentage des données d'entraînement à utiliser pour l'ensemble de validation. La valeur par défaut est 0,10.
    • TEST_FRACTION : pourcentage des données d'entraînement à utiliser pour l'ensemble de test. La valeur par défaut est 0,10.
    • TIME_COLUMN : colonne contenant les horodatages.

    Si l'une des fractions est spécifiée, elles doivent toutes être spécifiées. La somme des fractions doit être égale à 1,0. En savoir plus

    L'objet TimestampSplit n'est pas compatible avec les modèles de prévision.

    "timestampSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
      "key": TIME_COLUMN
    }
    

Étape suivante