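Get predictions from a video classification model

This page shows you how to get batch predictions from your video classification models by using the Google Cloud console or the Vertex AI API. Batch predictions are asynchronous requests. You request batch predictions directly from the model resource, without needing to deploy the model to an endpoint.

AutoML video models don't support online predictions.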

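Get batch predictions

To make a batch prediction request, you specify an input source and an output format in which Vertex AI stores the prediction results.

Input data requirements

The input for a batch request specifies the items to send to your model for prediction. For AutoML video models, you list the videos to make predictions for in a JSON Lines file, and then store that file in a Cloud Storage bucket. You can specify Infinity for the timeSegmentEnd field to indicate the end of the video. The following example shows a single line of an input JSON Lines file.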

{"content": "gs://sourcebucket/datasets/videos/source_video.mp4", "mimeType": "video/mp4", "timeSegmentStart": "0.0s", "timeSegmentEnd": "2.366667s"}
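
As a minimal sketch of how such a file might be produced, the following Python writes one JSON object per line using only the standard library. The helper name `write_batch_input`, the second time segment, and the local file name are illustrative assumptions, not part of the Vertex AI API; only the field names come from the example line above.

```python
import json
import os
import tempfile


def write_batch_input(path, videos):
    """Write one JSON object per line (JSON Lines), one per video segment.

    `videos` is a list of (gcs_uri, start, end) tuples. This helper is
    hypothetical and only mirrors the field names of the example line.
    """
    with open(path, "w", encoding="utf-8") as f:
        for uri, start, end in videos:
            row = {
                "content": uri,
                "mimeType": "video/mp4",
                "timeSegmentStart": start,
                "timeSegmentEnd": end,
            }
            # json.dumps always emits double-quoted strings, as JSON requires.
            f.write(json.dumps(row) + "\n")


# Placeholder segments; the second one is made up for illustration.
videos = [
    ("gs://sourcebucket/datasets/videos/source_video.mp4", "0.0s", "2.366667s"),
    ("gs://sourcebucket/datasets/videos/source_video.mp4", "2.366667s", "5.0s"),
]
input_path = os.path.join(tempfile.gettempdir(), "batch_input.jsonl")
write_batch_input(input_path, videos)
```

The resulting file would still need to be uploaded to a Cloud Storage bucket before it can be used as the input source.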

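Request a batch prediction

To request batch predictions, you can use the Google Cloud console or the Vertex AI API. Depending on the number of input items that you submit, a batch prediction job can take some time to complete.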

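Google Cloud console

Use the Google Cloud console to request a batch prediction.

  1. In the Google Cloud console, in the Vertex AI section, go to the Batch predictions page.

    Go to the Batch predictions page

  2. Click Create to open the New batch prediction window, and then complete the following steps:

    1. Enter a name for the batch prediction.
    2. For Model name, select the name of the model to use for this batch prediction.
    3. For Source path, specify the Cloud Storage location of your JSON Lines input file.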
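    4. For Destination path, specify the Cloud Storage location where the batch prediction results are stored. The output format is determined by your model's objective. AutoML models for video objectives output JSON Lines files.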

API

Use the Vertex AI API to send a batch prediction request.

REST

Before using any of the request data, make the following replacements:

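  • LOCATION_ID: The region where the model is stored and the batch prediction job is executed. For example, us-central1.
  • PROJECT_ID: Your project ID.
  • BATCH_JOB_NAME: The display name for the batch job.
  • MODEL_ID: The ID of the model to use for making predictions.
  • THRESHOLD_VALUE (optional): The model returns only predictions whose confidence scores are at least this value.
  • SEGMENT_CLASSIFICATION (optional): A Boolean value that determines whether to request segment-level classification. Vertex AI returns labels and their confidence scores for the entire time segment of the video that you specified in the input instance. The default is true.
  • SHOT_CLASSIFICATION (optional): A Boolean value that determines whether to request shot-level classification. Vertex AI determines the boundaries of each camera shot in the entire time segment of the video that you specified in the input instance. Vertex AI then returns labels and their confidence scores for each detected shot, along with the start and end time of the shot. The default is false.
  • ONE_SEC_INTERVAL_CLASSIFICATION (optional): A Boolean value that determines whether to request classification for the video at one-second intervals. Vertex AI returns labels and their confidence scores for each second of the entire time segment of the video that you specified in the input instance. The default is false.
  • URI: The Cloud Storage URI of your input JSON Lines file.
  • BUCKET: Your Cloud Storage bucket.
  • PROJECT_NUMBER: Your project's automatically generated project number.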

HTTP method and URL:

POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs

Request JSON body:

{
    "displayName": "BATCH_JOB_NAME",
    "model": "projects/PROJECT_ID/locations/LOCATION_ID/models/MODEL_ID",
    "modelParameters": {
      "confidenceThreshold": THRESHOLD_VALUE,
      "segmentClassification": SEGMENT_CLASSIFICATION,
      "shotClassification": SHOT_CLASSIFICATION,
      "oneSecIntervalClassification": ONE_SEC_INTERVAL_CLASSIFICATION
    },
    "inputConfig": {
        "instancesFormat": "jsonl",
        "gcsSource": {
            "uris": ["URI"]
        }
    },
    "outputConfig": {
        "predictionsFormat": "jsonl",
        "gcsDestination": {
            "outputUriPrefix": "BUCKET"
        }
    }
}
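
Because the model parameters are optional, one way to avoid hand-editing placeholders is to assemble the body programmatically. The following Python is a rough sketch under that assumption: the function `build_batch_request` and every sample argument value are hypothetical, and only the field names are taken from the request body above.

```python
import json


def build_batch_request(project_id, location_id, model_id, display_name,
                        input_uri, output_prefix, **model_params):
    """Assemble a batchPredictionJobs request body as a plain dict.

    Keyword arguments such as confidenceThreshold or shotClassification
    are copied into modelParameters only when the caller supplies them.
    """
    body = {
        "displayName": display_name,
        "model": (
            f"projects/{project_id}/locations/{location_id}/models/{model_id}"
        ),
        "inputConfig": {
            "instancesFormat": "jsonl",
            "gcsSource": {"uris": [input_uri]},
        },
        "outputConfig": {
            "predictionsFormat": "jsonl",
            "gcsDestination": {"outputUriPrefix": output_prefix},
        },
    }
    if model_params:
        body["modelParameters"] = dict(model_params)
    return body


# All values below are placeholders chosen for illustration.
request_body = build_batch_request(
    "my-project", "us-central1", "1234567890", "video-batch-job",
    "gs://my-bucket/batch_input.jsonl", "gs://my-bucket/output/",
    confidenceThreshold=0.5, shotClassification=True,
)
request_json = json.dumps(request_body, indent=2)
```

Serializing with `json.dumps` guarantees syntactically valid JSON, for example no trailing commas, which is easy to get wrong when editing the body by hand.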

To send your request, choose one of these options:

curl

Save the request body in a file named request.json, and then execute the following command:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs"

PowerShell

Save the request body in a file named request.json, and then execute the following command:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs" | Select-Object -Expand Content

You should receive a JSON response similar to the following:

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/batchPredictionJobs/BATCH_JOB_ID",
  "displayName": "BATCH_JOB_NAME",
  "model": "projects/PROJECT_NUMBER/locations/us-central1/models/MODEL_ID",
  "inputConfig": {
    "instancesFormat": "jsonl",
    "gcsSource": {
      "uris": [
        "CONTENT"
      ]
    }
  },
  "outputConfig": {
    "predictionsFormat": "jsonl",
    "gcsDestination": {
      "outputUriPrefix": "BUCKET"
    }
  },
  "state": "JOB_STATE_PENDING",
  "createTime": "2020-05-30T02:58:44.341643Z",
  "updateTime": "2020-05-30T02:58:44.341643Z",
  "modelDisplayName": "MODEL_NAME",
  "modelObjective": "MODEL_OBJECTIVE"
}

You can use the BATCH_JOB_ID to poll the status of the batch job until the job's state is JOB_STATE_SUCCEEDED.
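
The polling logic itself can be sketched independently of any client library. In the Python below, `wait_for_job` is a hypothetical helper, and `get_state` is a caller-supplied function that is assumed to fetch the job (for example, by requesting the job's resource name) and return its `state` string. The set of terminal states is also an assumption to check against the API reference; only JOB_STATE_PENDING and JOB_STATE_SUCCEEDED appear on this page.

```python
import time

# Assumed terminal states; verify against the JobState reference.
TERMINAL_STATES = {
    "JOB_STATE_SUCCEEDED",
    "JOB_STATE_FAILED",
    "JOB_STATE_CANCELLED",
}


def wait_for_job(get_state, poll_seconds=30.0, max_polls=120, sleep=time.sleep):
    """Call get_state() until it returns a terminal state or polls run out.

    `sleep` is injectable so that the loop can be exercised without waiting.
    """
    for _ in range(max_polls):
        state = get_state()
        if state in TERMINAL_STATES:
            return state
        sleep(poll_seconds)
    raise TimeoutError("batch prediction job did not reach a terminal state")


# Simulated job that succeeds on the third poll; no network calls are made.
fake_states = iter(
    ["JOB_STATE_PENDING", "JOB_STATE_RUNNING", "JOB_STATE_SUCCEEDED"]
)
final_state = wait_for_job(
    lambda: next(fake_states), poll_seconds=0, sleep=lambda seconds: None
)
print(final_state)
```

Stopping on failure and cancellation as well as success keeps the loop from running until the timeout when a job cannot succeed.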

Java

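Before trying this sample, follow the Java setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Java API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.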


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.BatchDedicatedResources;
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
import com.google.cloud.aiplatform.v1.BatchPredictionJob.InputConfig;
import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputConfig;
import com.google.cloud.aiplatform.v1.BatchPredictionJob.OutputInfo;
import com.google.cloud.aiplatform.v1.BigQueryDestination;
import com.google.cloud.aiplatform.v1.BigQuerySource;
import com.google.cloud.aiplatform.v1.CompletionStats;
import com.google.cloud.aiplatform.v1.GcsDestination;
import com.google.cloud.aiplatform.v1.GcsSource;
import com.google.cloud.aiplatform.v1.JobServiceClient;
import com.google.cloud.aiplatform.v1.JobServiceSettings;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.MachineSpec;
import com.google.cloud.aiplatform.v1.ManualBatchTuningParameters;
import com.google.cloud.aiplatform.v1.ModelName;
import com.google.cloud.aiplatform.v1.ResourcesConsumed;
import com.google.cloud.aiplatform.v1.schema.predict.params.VideoClassificationPredictionParams;
import com.google.protobuf.Any;
import com.google.protobuf.Value;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.List;

public class CreateBatchPredictionJobVideoClassificationSample {

  public static void main(String[] args) throws IOException {
    String batchPredictionDisplayName = "YOUR_VIDEO_CLASSIFICATION_DISPLAY_NAME";
    String modelId = "YOUR_MODEL_ID";
    String gcsSourceUri =
        "gs://YOUR_GCS_SOURCE_BUCKET/path_to_your_video_source/[file.csv/file.jsonl]";
    String gcsDestinationOutputUriPrefix =
        "gs://YOUR_GCS_SOURCE_BUCKET/destination_output_uri_prefix/";
    String project = "YOUR_PROJECT_ID";
    createBatchPredictionJobVideoClassification(
        batchPredictionDisplayName, modelId, gcsSourceUri, gcsDestinationOutputUriPrefix, project);
  }

  static void createBatchPredictionJobVideoClassification(
      String batchPredictionDisplayName,
      String modelId,
      String gcsSourceUri,
      String gcsDestinationOutputUriPrefix,
      String project)
      throws IOException {
    JobServiceSettings jobServiceSettings =
        JobServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (JobServiceClient jobServiceClient = JobServiceClient.create(jobServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);

      VideoClassificationPredictionParams modelParamsObj =
          VideoClassificationPredictionParams.newBuilder()
              .setConfidenceThreshold(((float) 0.5))
              .setMaxPredictions(10000)
              .setSegmentClassification(true)
              .setShotClassification(true)
              .setOneSecIntervalClassification(true)
              .build();

      Value modelParameters = ValueConverter.toValue(modelParamsObj);

      ModelName modelName = ModelName.of(project, location, modelId);
      GcsSource.Builder gcsSource = GcsSource.newBuilder();
      gcsSource.addUris(gcsSourceUri);
      InputConfig inputConfig =
          InputConfig.newBuilder().setInstancesFormat("jsonl").setGcsSource(gcsSource).build();

      GcsDestination gcsDestination =
          GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
      OutputConfig outputConfig =
          OutputConfig.newBuilder()
              .setPredictionsFormat("jsonl")
              .setGcsDestination(gcsDestination)
              .build();

      BatchPredictionJob batchPredictionJob =
          BatchPredictionJob.newBuilder()
              .setDisplayName(batchPredictionDisplayName)
              .setModel(modelName.toString())
              .setModelParameters(modelParameters)
              .setInputConfig(inputConfig)
              .setOutputConfig(outputConfig)
              .build();
      BatchPredictionJob batchPredictionJobResponse =
          jobServiceClient.createBatchPredictionJob(locationName, batchPredictionJob);

      System.out.println("Create Batch Prediction Job Video Classification Response");
      System.out.format("\tName: %s\n", batchPredictionJobResponse.getName());
      System.out.format("\tDisplay Name: %s\n", batchPredictionJobResponse.getDisplayName());
      System.out.format("\tModel %s\n", batchPredictionJobResponse.getModel());
      System.out.format(
          "\tModel Parameters: %s\n", batchPredictionJobResponse.getModelParameters());

      System.out.format("\tState: %s\n", batchPredictionJobResponse.getState());
      System.out.format("\tCreate Time: %s\n", batchPredictionJobResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", batchPredictionJobResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", batchPredictionJobResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", batchPredictionJobResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", batchPredictionJobResponse.getLabelsMap());

      InputConfig inputConfigResponse = batchPredictionJobResponse.getInputConfig();
      System.out.println("\tInput Config");
      System.out.format("\t\tInstances Format: %s\n", inputConfigResponse.getInstancesFormat());

      GcsSource gcsSourceResponse = inputConfigResponse.getGcsSource();
      System.out.println("\t\tGcs Source");
      System.out.format("\t\t\tUris %s\n", gcsSourceResponse.getUrisList());

      BigQuerySource bigQuerySource = inputConfigResponse.getBigquerySource();
      System.out.println("\t\tBigquery Source");
      System.out.format("\t\t\tInput_uri: %s\n", bigQuerySource.getInputUri());

      OutputConfig outputConfigResponse = batchPredictionJobResponse.getOutputConfig();
      System.out.println("\tOutput Config");
      System.out.format(
          "\t\tPredictions Format: %s\n", outputConfigResponse.getPredictionsFormat());

      GcsDestination gcsDestinationResponse = outputConfigResponse.getGcsDestination();
      System.out.println("\t\tGcs Destination");
      System.out.format(
          "\t\t\tOutput Uri Prefix: %s\n", gcsDestinationResponse.getOutputUriPrefix());

      BigQueryDestination bigQueryDestination = outputConfigResponse.getBigqueryDestination();
      System.out.println("\t\tBig Query Destination");
      System.out.format("\t\t\tOutput Uri: %s\n", bigQueryDestination.getOutputUri());

      BatchDedicatedResources batchDedicatedResources =
          batchPredictionJobResponse.getDedicatedResources();
      System.out.println("\tBatch Dedicated Resources");
      System.out.format(
          "\t\tStarting Replica Count: %s\n", batchDedicatedResources.getStartingReplicaCount());
      System.out.format(
          "\t\tMax Replica Count: %s\n", batchDedicatedResources.getMaxReplicaCount());

      MachineSpec machineSpec = batchDedicatedResources.getMachineSpec();
      System.out.println("\t\tMachine Spec");
      System.out.format("\t\t\tMachine Type: %s\n", machineSpec.getMachineType());
      System.out.format("\t\t\tAccelerator Type: %s\n", machineSpec.getAcceleratorType());
      System.out.format("\t\t\tAccelerator Count: %s\n", machineSpec.getAcceleratorCount());

      ManualBatchTuningParameters manualBatchTuningParameters =
          batchPredictionJobResponse.getManualBatchTuningParameters();
      System.out.println("\tManual Batch Tuning Parameters");
      System.out.format("\t\tBatch Size: %s\n", manualBatchTuningParameters.getBatchSize());

      OutputInfo outputInfo = batchPredictionJobResponse.getOutputInfo();
      System.out.println("\tOutput Info");
      System.out.format("\t\tGcs Output Directory: %s\n", outputInfo.getGcsOutputDirectory());
      System.out.format("\t\tBigquery Output Dataset: %s\n", outputInfo.getBigqueryOutputDataset());

      Status status = batchPredictionJobResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
      List<Any> details = status.getDetailsList();

      for (Status partialFailure : batchPredictionJobResponse.getPartialFailuresList()) {
        System.out.println("\tPartial Failure");
        System.out.format("\t\tCode: %s\n", partialFailure.getCode());
        System.out.format("\t\tMessage: %s\n", partialFailure.getMessage());
        List<Any> partialFailureDetailsList = partialFailure.getDetailsList();
      }

      ResourcesConsumed resourcesConsumed = batchPredictionJobResponse.getResourcesConsumed();
      System.out.println("\tResources Consumed");
      System.out.format("\t\tReplica Hours: %s\n", resourcesConsumed.getReplicaHours());

      CompletionStats completionStats = batchPredictionJobResponse.getCompletionStats();
      System.out.println("\tCompletion Stats");
      System.out.format("\t\tSuccessful Count: %s\n", completionStats.getSuccessfulCount());
      System.out.format("\t\tFailed Count: %s\n", completionStats.getFailedCount());
      System.out.format("\t\tIncomplete Count: %s\n", completionStats.getIncompleteCount());
    }
  }
}

Node.js

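Before trying this sample, follow the Node.js setup instructions in the Vertex AI quickstart using client libraries. For more information, see the Vertex AI Node.js API reference documentation.

To authenticate to Vertex AI, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.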

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */

// const batchPredictionDisplayName = 'YOUR_BATCH_PREDICTION_DISPLAY_NAME';
// const modelId = 'YOUR_MODEL_ID';
// const gcsSourceUri = 'YOUR_GCS_SOURCE_URI';
// const gcsDestinationOutputUriPrefix = 'YOUR_GCS_DEST_OUTPUT_URI_PREFIX';
//    eg. "gs://<your-gcs-bucket>/destination_path"
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {params} = aiplatform.protos.google.cloud.aiplatform.v1.schema.predict;

// Imports the Google Cloud Job Service Client library
const {JobServiceClient} = require('@google-cloud/aiplatform').v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const jobServiceClient = new JobServiceClient(clientOptions);

async function createBatchPredictionJobVideoClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;
  const modelName = `projects/${project}/locations/${location}/models/${modelId}`;

  // For more information on how to configure the model parameters object, see
  // https://cloud.google.com/ai-platform-unified/docs/predictions/batch-predictions
  const modelParamsObj = new params.VideoClassificationPredictionParams({
    confidenceThreshold: 0.5,
    maxPredictions: 1000,
    segmentClassification: true,
    shotClassification: true,
    oneSecIntervalClassification: true,
  });

  const modelParameters = modelParamsObj.toValue();

  const inputConfig = {
    instancesFormat: 'jsonl',
    gcsSource: {uris: [gcsSourceUri]},
  };
  const outputConfig = {
    predictionsFormat: 'jsonl',
    gcsDestination: {outputUriPrefix: gcsDestinationOutputUriPrefix},
  };
  const batchPredictionJob = {
    displayName: batchPredictionDisplayName,
    model: modelName,
    modelParameters,
    inputConfig,
    outputConfig,
  };
  const request = {
    parent,
    batchPredictionJob,
  };

  // Create batch prediction job request
  const [response] = await jobServiceClient.createBatchPredictionJob(request);

  console.log('Create batch prediction job video classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createBatchPredictionJobVideoClassification();

Python

To learn how to install or update the Vertex AI SDK for Python, see Install the Vertex AI SDK for Python. For more information, see the Python API reference documentation.

from typing import Sequence, Union

from google.cloud import aiplatform


def create_batch_prediction_job_sample(
    project: str,
    location: str,
    model_resource_name: str,
    job_display_name: str,
    gcs_source: Union[str, Sequence[str]],
    gcs_destination: str,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    my_model = aiplatform.Model(model_resource_name)

    batch_prediction_job = my_model.batch_predict(
        job_display_name=job_display_name,
        gcs_source=gcs_source,
        gcs_destination_prefix=gcs_destination,
        sync=sync,
    )

    batch_prediction_job.wait()

    print(batch_prediction_job.display_name)
    print(batch_prediction_job.resource_name)
    print(batch_prediction_job.state)
    return batch_prediction_job

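Retrieve batch prediction results

Vertex AI sends batch prediction output to the destination that you specified.

When a batch prediction job is complete, the prediction output is stored in the Cloud Storage bucket that you specified in your request.

Example batch prediction results

The following is an example of the batch prediction results from a video classification model.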

{
  "instance": {
    "content": "gs://bucket/video.mp4",
    "mimeType": "video/mp4",
    "timeSegmentStart": "1s",
    "timeSegmentEnd": "5s"
  },
  "prediction": [{
    "id": "1",
    "displayName": "cat",
    "type": "segment-classification",
    "timeSegmentStart": "1s",
    "timeSegmentEnd": "5s",
    "confidence": 0.7
  }, {
    "id": "1",
    "displayName": "cat",
    "type": "shot-classification",
    "timeSegmentStart": "1s",
    "timeSegmentEnd": "4s",
    "confidence": 0.9
  }, {
    "id": "2",
    "displayName": "dog",
    "type": "shot-classification",
    "timeSegmentStart": "4s",
    "timeSegmentEnd": "5s",
    "confidence": 0.6
  }, {
    "id": "1",
    "displayName": "cat",
    "type": "one-sec-interval-classification",
    "timeSegmentStart": "1s",
    "timeSegmentEnd": "1s",
    "confidence": 0.95
  }, {
    "id": "1",
    "displayName": "cat",
    "type": "one-sec-interval-classification",
    "timeSegmentStart": "2s",
    "timeSegmentEnd": "2s",
    "confidence": 0.9
  }, {
    "id": "1",
    "displayName": "cat",
    "type": "one-sec-interval-classification",
    "timeSegmentStart": "3s",
    "timeSegmentEnd": "3s",
    "confidence": 0.85
  }, {
    "id": "2",
    "displayName": "dog",
    "type": "one-sec-interval-classification",
    "timeSegmentStart": "4s",
    "timeSegmentEnd": "4s",
    "confidence": 0.6
  }]
}
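
To work with these results, it can help to group predictions by their type. The following Python is a minimal sketch that assumes each line of an output file holds one JSON object shaped like the example above; the function `group_predictions`, the confidence cutoff, and the abbreviated sample record are illustrative assumptions.

```python
import json
from collections import defaultdict


def group_predictions(jsonl_text, min_confidence=0.0):
    """Group predictions by classification type, dropping low-confidence ones."""
    grouped = defaultdict(list)
    for line in jsonl_text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        record = json.loads(line)
        for pred in record.get("prediction", []):
            if pred["confidence"] >= min_confidence:
                grouped[pred["type"]].append(
                    (pred["displayName"], pred["timeSegmentStart"],
                     pred["timeSegmentEnd"], pred["confidence"])
                )
    return dict(grouped)


# One result line, abbreviated from the example on this page.
sample_line = json.dumps({
    "instance": {"content": "gs://bucket/video.mp4", "mimeType": "video/mp4",
                 "timeSegmentStart": "1s", "timeSegmentEnd": "5s"},
    "prediction": [
        {"id": "1", "displayName": "cat", "type": "segment-classification",
         "timeSegmentStart": "1s", "timeSegmentEnd": "5s", "confidence": 0.7},
        {"id": "1", "displayName": "cat", "type": "shot-classification",
         "timeSegmentStart": "1s", "timeSegmentEnd": "4s", "confidence": 0.9},
        {"id": "2", "displayName": "dog", "type": "shot-classification",
         "timeSegmentStart": "4s", "timeSegmentEnd": "5s", "confidence": 0.6},
    ],
})
by_type = group_predictions(sample_line, min_confidence=0.65)
```

With a cutoff of 0.65, the "dog" shot (confidence 0.6) is dropped, which leaves one segment-level and one shot-level prediction.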