動画オブジェクト トラッキング モデルをトレーニングする

このページでは、Google Cloud コンソールまたは Vertex AI API を使用して、動画データセットから AutoML オブジェクト トラッキング モデルをトレーニングする方法について説明します。

AutoML モデルをトレーニングする

Google Cloud コンソール

  1. Google Cloud コンソールの [Vertex AI] セクションで、[データセット] ページに移動します。

    [データセット] ページに移動

  2. モデルのトレーニングに使用するデータセットの名前をクリックして、詳細ページを開きます。

  3. [新しいモデルのトレーニング] をクリックします。

  4. 新しいモデルの表示名を入力します。

  5. トレーニング データの分割方法を手動で設定する場合は、[ADVANCED OPTIONS] を開き、データ分割オプションを選択します詳細

  6. [続行] をクリックします。

  7. モデルのトレーニング方法を選択します。

    • AutoML は幅広いユースケースに適しています。
    • Seq2seq+ はテストに適しています。このアーキテクチャはシンプルで、検索スペースが小さいため、AutoML よりも速く収束する可能性が高くなります。Google でのテストによると、Seq2Seq+ は少ない時間予算で十分に機能し、サイズが 1 GB 未満のデータセットで高いパフォーマンスを発揮します。
    [続行] をクリックします。

  8. [トレーニングを開始] をクリックします。

    データのサイズ、複雑さ、トレーニング予算(指定された場合)に応じて、モデルのトレーニングに何時間もかかることがあります。このタブを閉じて、後で戻ることもできます。モデルのトレーニングが完了すると、メールが送られてきます。

    トレーニング開始から数分後には、モデルのプロパティ情報からトレーニング ノード時間の見積もりを確認できます。トレーニングをキャンセルした場合、現在のプロダクトに料金はかかりません。

API

お使いの言語または環境に応じて、以下のタブを選択してください。

REST

リクエストのデータを使用する前に、次のように置き換えます。

  • LOCATION: データセットが配置され、モデルが格納されるリージョン。例: us-central1
  • PROJECT: 実際のプロジェクト ID
  • MODEL_DISPLAY_NAME: 新しくトレーニングされたモデルの表示名。
  • DATASET_ID: トレーニング データセットの ID。
  • filterSplit オブジェクトは省略可です。データ分割の制御に使用します。データ分割の制御については、REST を使用したデータ分割の制御をご覧ください。
  • PROJECT_NUMBER: プロジェクトに自動生成されたプロジェクト番号

HTTP メソッドと URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

リクエストの本文(JSON):

{
    "displayName": "MODE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
    "trainingTaskInputs": {},
    "modelToUpload": {"displayName": "MODE_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
      "filterSplit": {
        "trainingFilter": "labels.ml_use = training",
        "validationFilter": "labels.ml_use = -",
        "testFilter": "labels.ml_use = test"
      }
    }
}

リクエストを送信するには、次のいずれかのオプションを展開します。

次のような JSON レスポンスが返されます。

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/2307109646608891904",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-04-18T01:22:57.479336Z",
  "updateTime": "2020-04-18T01:22:57.479336Z"
}

Java

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlVideoObjectTrackingInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineVideoObjectTrackingSample {

  public static void main(String[] args) throws IOException {
    String trainingPipelineVideoObjectTracking =
        "YOUR_TRAINING_PIPELINE_VIDEO_OBJECT_TRACKING_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    createTrainingPipelineVideoObjectTracking(
        trainingPipelineVideoObjectTracking, datasetId, modelDisplayName, project);
  }

  static void createTrainingPipelineVideoObjectTracking(
      String trainingPipelineVideoObjectTracking,
      String datasetId,
      String modelDisplayName,
      String project)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_video_object_tracking_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlVideoObjectTrackingInputs trainingTaskInputs =
          AutoMlVideoObjectTrackingInputs.newBuilder().setModelType(ModelType.CLOUD).build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineVideoObjectTracking)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline createTrainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Video Object Tracking Response");
      System.out.format("Name: %s\n", createTrainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", createTrainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n",
          createTrainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n",
          createTrainingPipelineResponse.getTrainingTaskInputs().toString());
      System.out.format(
          "Training Task Metadata: %s\n",
          createTrainingPipelineResponse.getTrainingTaskMetadata().toString());

      System.out.format("State: %s\n", createTrainingPipelineResponse.getState().toString());
      System.out.format(
          "Create Time: %s\n", createTrainingPipelineResponse.getCreateTime().toString());
      System.out.format("StartTime %s\n", createTrainingPipelineResponse.getStartTime().toString());
      System.out.format("End Time: %s\n", createTrainingPipelineResponse.getEndTime().toString());
      System.out.format(
          "Update Time: %s\n", createTrainingPipelineResponse.getUpdateTime().toString());
      System.out.format("Labels: %s\n", createTrainingPipelineResponse.getLabelsMap().toString());

      InputDataConfig inputDataConfigResponse = createTrainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data config");
      System.out.format("Dataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfigResponse.getFractionSplit();
      System.out.println("Fraction split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = createTrainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());
      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());

      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %s\n", modelResponse.getLabelsMap());

      Status status = createTrainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Node.js の設定手順を完了してください。詳細については、Vertex AI Node.js API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlVideoObjectTrackingInputs.ModelType;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineVideoObjectTracking() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputsObj =
    new definition.AutoMlVideoObjectTrackingInputs({
      modelType: ModelType.CLOUD,
    });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline video object tracking response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineVideoObjectTracking();

Vertex AI SDK for Python

Vertex AI SDK for Python のインストールまたは更新の方法については、Vertex AI SDK for Python をインストールするをご覧ください。詳細については、Vertex AI SDK for Python API のリファレンス ドキュメントをご覧ください。

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import trainingjob


def create_training_pipeline_video_object_tracking_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    training_task_inputs = trainingjob.definition.AutoMlVideoObjectTrackingInputs(
        model_type="CLOUD",
    ).to_value()

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_video_object_tracking_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {"dataset_id": dataset_id},
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

REST を使用してデータ分割を制御する

トレーニング セット、検証セット、テストセットの間でトレーニング データをどのように分割するかを制御できます。Vertex AI API を使用する場合は、Split オブジェクトを使用してデータの分割を決定します。トレーニング データの分割に使用できるオブジェクト タイプの一つとして、Split オブジェクトを InputConfig オブジェクトに含めることができます。選択できるメソッドは 1 つのみです。

  • FractionSplit:
    • TRAINING_FRACTION: トレーニング セットに使用されるトレーニング データの割合。
    • VALIDATION_FRACTION: 検証セットに使用されるトレーニング データの割合。動画データに対しては使用されません。
    • TEST_FRACTION: テストセットに使用されるトレーニング データの割合。

    いずれか一つでも指定する場合は、すべてを指定する必要があります。割合の合計が 1.0 になるようにしてください。割合のデフォルト値は、データ型によって異なります。詳細については、こちらをご覧ください。

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    
  • FilterSplit:
    • TRAINING_FILTER: このフィルタに一致するデータ項目がトレーニング セットに使用されます。
    • VALIDATION_FILTER: このフィルタに一致するデータ項目が検証セットに使用されます。動画データの場合は "-" にする必要があります。
    • TEST_FILTER: このフィルタに一致するデータ項目がテストセットに使用されます。

    これらのフィルタは、ml_use ラベル、またはデータに適用された任意のラベルとともに使用できます。ml-use ラベルその他のラベルを使用してデータをフィルタリングする方法をご確認ください。

    次の例は、検証セットを含む、ml_use ラベルを持つ filterSplit オブジェクトの使用方法を示しています。

    "filterSplit": {
    "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }