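AutoML Vision API Tutorial

This tutorial demonstrates how to use AutoML Vision to create a new model from your own set of training images, evaluate the results, and predict the classification of a test image.

The tutorial uses a dataset of images of five kinds of flowers (sunflowers, tulips, daisies, roses, and dandelions). It covers training a custom model, evaluating the model's performance, and classifying new images with the custom model.

Prerequisites

Configure your project environment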

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud Console, on the project selector page, select or create a Google Cloud project.
  3. Make sure that billing is enabled for your Google Cloud project.
  4. Enable the AutoML Vision API.
  5. Install the Google Cloud CLI.
  6. Follow the instructions to create a service account and download a key file.
  7. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the path of the service account key file that you downloaded when you created the service account. For example:
    export GOOGLE_APPLICATION_CREDENTIALS=KEY_FILE
  8. Use the following commands to add your new service account to the AutoML Editor IAM role. Replace PROJECT_ID with the name of your Google Cloud project, and replace SERVICE_ACCOUNT_NAME with the name of your new service account, for example service-account1@myproject.iam.gserviceaccount.com. In the first binding, replace YOUR_USERID@YOUR_DOMAIN with your own user account.
    gcloud auth login
    gcloud projects add-iam-policy-binding PROJECT_ID \
     --member="user:YOUR_USERID@YOUR_DOMAIN" \
     --role="roles/automl.admin"
    gcloud projects add-iam-policy-binding PROJECT_ID \
     --member=serviceAccount:SERVICE_ACCOUNT_NAME \
     --role="roles/automl.editor"
  9. Allow the AutoML Vision service account to access your Google Cloud project resources:
    gcloud projects add-iam-policy-binding PROJECT_ID \
     --member="serviceAccount:custom-vision@appspot.gserviceaccount.com" \
     --role="roles/storage.admin"
  10. Install the client library.
  11. Set the PROJECT_ID and REGION_NAME environment variables.

    Replace PROJECT_ID with the ID of your Google Cloud project. AutoML Vision currently requires the location us-central1.
    export PROJECT_ID="PROJECT_ID"
    export REGION_NAME="us-central1"
  12. Create a Cloud Storage bucket to store the images that you will use to train your custom model.

    The bucket name must be in the format $PROJECT_ID-vcm. The following command creates a bucket named $PROJECT_ID-vcm in the us-central1 region.
    gsutil mb -p ${PROJECT_ID} -c regional -l us-central1 gs://${PROJECT_ID}-vcm/
  13. Set the BUCKET variable.
    export BUCKET=${PROJECT_ID}-vcm
  14. Copy the publicly available flower image dataset from gs://cloud-samples-data/ai-platform/flowers/ to your Cloud Storage bucket.

    In your Cloud Shell session, enter:
    gsutil -m cp -R gs://cloud-samples-data/ai-platform/flowers/ gs://${BUCKET}/img/

    The copy takes about 20 minutes to complete.

    This command also copies the all_data.csv file, which lists the original file names and their labels.

  15. The sample dataset includes a CSV file that contains the location and label of each image (for details on the required format, see Preparing your training data). Update the CSV file so that it points to the files in your own bucket:
    gsutil cat gs://${BUCKET}/img/flowers/all_data.csv | sed "s:cloud-ml-data/img/flower_photos/:${BUCKET}/img/flowers/:" > all_data.csv
    Then copy the updated CSV file to your bucket:
    gsutil cp all_data.csv gs://${BUCKET}/csv/
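
The sed command in the last step swaps the original bucket prefix in each CSV row for a path in your own bucket. As a rough illustration of that substitution, here is a minimal Python sketch. The function name rewrite_csv_paths, the sample rows, and the bucket name are hypothetical, and the snippet works on an in-memory string rather than on Cloud Storage.

```python
def rewrite_csv_paths(csv_text: str, old_prefix: str, new_prefix: str) -> str:
    """Replace the first occurrence of old_prefix on every line, like sed 's:old:new:'."""
    lines = csv_text.splitlines()
    return "\n".join(line.replace(old_prefix, new_prefix, 1) for line in lines)


# Hypothetical rows in an "image URI,label" layout.
sample = (
    "gs://cloud-ml-data/img/flower_photos/daisy/1.jpg,daisy\n"
    "gs://cloud-ml-data/img/flower_photos/roses/2.jpg,roses"
)
bucket = "my-project-vcm"  # assumed bucket name in the $PROJECT_ID-vcm format
updated = rewrite_csv_paths(
    sample, "cloud-ml-data/img/flower_photos/", f"{bucket}/img/flowers/"
)
print(updated)
```

Only the path prefix changes; the file names and labels in each row are left untouched.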

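Source code file locations

You can download the source code from the locations provided below. After downloading, you can copy the source code into your AutoML Vision project folder.

Python

This tutorial uses the following Python files, which are run in the steps below:

  • vision_classification_create_dataset.py
  • import_dataset.py
  • vision_classification_create_model.py
  • list_model_evaluations.py
  • vision_classification_predict.py
  • delete_model.py

Java

This tutorial uses the following Java files, which are run in the steps below:

  • VisionClassificationCreateDataset.java
  • ImportDataset.java
  • VisionClassificationCreateModel.java
  • ListModelEvaluations.java
  • VisionClassificationPredict.java
  • DeleteModel.java

Node.js

This tutorial uses the following Node.js programs, which are run in the steps below:

  • vision_classification_create_dataset.js
  • import_dataset.js
  • vision_classification_create_model.js
  • list_model_evaluations.js
  • vision_classification_predict.js
  • delete_model.js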

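Run the application

Step 1: Create the Flowers dataset

To create a custom model, first create an empty dataset that will eventually hold the model's training data. When you create the dataset, you specify the type of classification that you want the custom model to perform:

  • Multiclass (MULTICLASS) classification assigns a single label to each classified image.
  • Multilabel (MULTILABEL) classification allows an image to be assigned multiple labels.

This tutorial creates a dataset named flowers and uses multiclass (MULTICLASS) classification.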
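
To make the difference between the two classification types concrete, the following minimal Python sketch checks labeled rows locally before import. It is an illustration only, not part of the AutoML API: the function name check_label_counts and the sample rows are hypothetical, and the rule applied (at most one label per image for MULTICLASS, any number for MULTILABEL) is taken from the comments in the samples below.

```python
from collections import defaultdict


def check_label_counts(rows, classification_type):
    """Return the images whose label count violates the classification type.

    rows is an iterable of (image_uri, label) pairs. For MULTICLASS, an image
    may carry at most one label; for MULTILABEL, any number is allowed.
    """
    if classification_type not in ("MULTICLASS", "MULTILABEL"):
        raise ValueError(f"unknown classification type: {classification_type}")
    labels_by_image = defaultdict(set)
    for image_uri, label in rows:
        labels_by_image[image_uri].add(label)
    if classification_type == "MULTILABEL":
        return []
    return sorted(uri for uri, labels in labels_by_image.items() if len(labels) > 1)


rows = [
    ("gs://bucket/img/a.jpg", "daisy"),
    ("gs://bucket/img/b.jpg", "roses"),
    ("gs://bucket/img/b.jpg", "tulips"),  # second label for the same image
]
print(check_label_counts(rows, "MULTICLASS"))
print(check_label_counts(rows, "MULTILABEL"))
```

The flowers dataset has one label per image, which is why the tutorial uses MULTICLASS.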

Copy the code

Python

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# display_name = "your_datasets_display_name"

client = automl.AutoMlClient()

# A resource that represents Google Cloud Platform location.
project_location = f"projects/{project_id}/locations/us-central1"
# Specify the classification type
# Types:
# MultiLabel: Multiple labels are allowed for one example.
# MultiClass: At most one label is allowed per example.
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#classificationtype
metadata = automl.ImageClassificationDatasetMetadata(
    classification_type=automl.ClassificationType.MULTILABEL
)
dataset = automl.Dataset(
    display_name=display_name,
    image_classification_dataset_metadata=metadata,
)

# Create a dataset with the dataset metadata in the region.
response = client.create_dataset(
    parent=project_location, dataset=dataset, timeout=300
)

created_dataset = response.result()

# Display the dataset information
print(f"Dataset name: {created_dataset.name}")
print("Dataset id: {}".format(created_dataset.name.split("/")[-1]))

Java

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.ClassificationType;
import com.google.cloud.automl.v1.Dataset;
import com.google.cloud.automl.v1.ImageClassificationDatasetMetadata;
import com.google.cloud.automl.v1.LocationName;
import com.google.cloud.automl.v1.OperationMetadata;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

class VisionClassificationCreateDataset {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String displayName = "YOUR_DATASET_NAME";
    createDataset(projectId, displayName);
  }

  // Create a dataset
  static void createDataset(String projectId, String displayName)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // A resource that represents Google Cloud Platform location.
      LocationName projectLocation = LocationName.of(projectId, "us-central1");

      // Specify the classification type
      // Types:
      // MultiLabel: Multiple labels are allowed for one example.
      // MultiClass: At most one label is allowed per example.
      ClassificationType classificationType = ClassificationType.MULTILABEL;
      ImageClassificationDatasetMetadata metadata =
          ImageClassificationDatasetMetadata.newBuilder()
              .setClassificationType(classificationType)
              .build();
      Dataset dataset =
          Dataset.newBuilder()
              .setDisplayName(displayName)
              .setImageClassificationDatasetMetadata(metadata)
              .build();
      OperationFuture<Dataset, OperationMetadata> future =
          client.createDatasetAsync(projectLocation, dataset);

      Dataset createdDataset = future.get();

      // Display the dataset information.
      System.out.format("Dataset name: %s\n", createdDataset.getName());
      // To get the dataset id, you have to parse it out of the `name` field. As dataset Ids are
      // required for other methods.
      // Name Form: `projects/{project_id}/locations/{location_id}/datasets/{dataset_id}`
      String[] names = createdDataset.getName().split("/");
      String datasetId = names[names.length - 1];
      System.out.format("Dataset id: %s\n", datasetId);
    }
  }
}

Node.js

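For more information, see the AutoML Vision Node.js API reference documentation.

To authenticate to AutoML Vision, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.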

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const displayName = 'YOUR_DISPLAY_NAME';

// Imports the Google Cloud AutoML library
const {AutoMlClient} = require('@google-cloud/automl').v1;

// Instantiates a client
const client = new AutoMlClient();

async function createDataset() {
  // Construct request
  // Specify the classification type
  // Types:
  // MultiLabel: Multiple labels are allowed for one example.
  // MultiClass: At most one label is allowed per example.
  const request = {
    parent: client.locationPath(projectId, location),
    dataset: {
      displayName: displayName,
      imageClassificationDatasetMetadata: {
        classificationType: 'MULTILABEL',
      },
    },
  };

  // Create dataset
  const [operation] = await client.createDataset(request);

  // Wait for operation to complete.
  const [response] = await operation.promise();

  console.log(`Dataset name: ${response.name}`);
  // The dataset ID is the last segment of the resource name.
  console.log(`Dataset id: ${response.name.split('/').pop()}`);
}

createDataset();

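Request

Run the create_dataset function to create an empty dataset. You need to modify the following lines of code:

  • Set project_id to your PROJECT_ID.
  • Set display_name for the dataset (flowers).
  • Change MULTILABEL to MULTICLASS.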

Python

python3 vision_classification_create_dataset.py

Java

mvn compile exec:java -Dexec.mainClass="com.example.automl.VisionClassificationCreateDataset"

Node.js

node vision_classification_create_dataset.js

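Response

The response includes the details of the newly created dataset, including the dataset ID that you use to reference the dataset in later requests. We recommend that you set the environment variable DATASET_ID to the returned dataset ID value.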

Dataset name: projects/216065747626/locations/us-central1/datasets/ICN7372141011130533778
Dataset id: ICN7372141011130533778
Dataset display name: flowers
Image classification dataset specification:
       classification_type: MULTICLASS
Dataset example count: 0
Dataset create time:
       seconds: 1530251987
       nanos: 216586000
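
The dataset ID is the last segment of the resource name shown in the response. As a small sketch of how the name could be split into its parts, the helper below parses names of the form projects/{project_id}/locations/{location_id}/datasets/{dataset_id}. The function name parse_resource_name is hypothetical, and the sketch assumes that operation and model names follow the same six-segment layout, as the sample output in this tutorial suggests.

```python
def parse_resource_name(name: str) -> dict:
    """Split 'projects/P/locations/L/<collection>/<id>' into its parts."""
    parts = name.split("/")
    if len(parts) != 6 or parts[0] != "projects" or parts[2] != "locations":
        raise ValueError(f"unexpected resource name: {name}")
    return {
        "project": parts[1],
        "location": parts[3],
        "collection": parts[4],
        "id": parts[5],
    }


name = "projects/216065747626/locations/us-central1/datasets/ICN7372141011130533778"
info = parse_resource_name(name)
# Print a shell line that stores the ID for later steps.
print(f"export DATASET_ID={info['id']}")
```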

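Step 2: Import images into the dataset

Next, populate the dataset with training images that have target labels.

The import_data function takes a CSV file as input. The file lists the location of every training image and the correct label for each one. (For details on the required format, see Preparing your data.) In this tutorial, we use the labeled images that you copied to your Cloud Storage bucket, which are listed in gs://$PROJECT_ID-vcm/csv/all_data.csv.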
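
Before importing, it can be useful to check locally how many images each label has. The following minimal Python sketch does that for CSV text. It assumes a simple two-column layout of image URI followed by label, which is a simplification for illustration; the full format is described in the data preparation guide. The function name count_labels and the sample rows are hypothetical.

```python
import csv
import io
from collections import Counter


def count_labels(csv_text: str) -> Counter:
    """Count images per label in 'image URI,label' rows, skipping blank lines."""
    counts = Counter()
    for row in csv.reader(io.StringIO(csv_text)):
        if not row:
            continue
        if len(row) != 2 or not row[0].startswith("gs://"):
            raise ValueError(f"unexpected row: {row}")
        counts[row[1]] += 1
    return counts


sample = """gs://my-project-vcm/img/flowers/daisy/1.jpg,daisy
gs://my-project-vcm/img/flowers/daisy/2.jpg,daisy
gs://my-project-vcm/img/flowers/roses/3.jpg,roses
"""
for label, n in sorted(count_labels(sample).items()):
    print(label, n)
```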

Copy the code

Python

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# dataset_id = "YOUR_DATASET_ID"
# path = "gs://YOUR_BUCKET_ID/path/to/data.csv"

client = automl.AutoMlClient()
# Get the full path of the dataset.
dataset_full_id = client.dataset_path(project_id, "us-central1", dataset_id)
# Get the multiple Google Cloud Storage URIs
input_uris = path.split(",")
gcs_source = automl.GcsSource(input_uris=input_uris)
input_config = automl.InputConfig(gcs_source=gcs_source)
# Import data from the input URI
response = client.import_data(name=dataset_full_id, input_config=input_config)

print("Processing import...")
print(f"Data imported. {response.result()}")

Java

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.DatasetName;
import com.google.cloud.automl.v1.GcsSource;
import com.google.cloud.automl.v1.InputConfig;
import com.google.cloud.automl.v1.OperationMetadata;
import com.google.protobuf.Empty;
import java.io.IOException;
import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

class ImportDataset {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException, TimeoutException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String path = "gs://BUCKET_ID/path_to_training_data.csv";
    importDataset(projectId, datasetId, path);
  }

  // Import a dataset
  static void importDataset(String projectId, String datasetId, String path)
      throws IOException, ExecutionException, InterruptedException, TimeoutException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // Get the complete path of the dataset.
      DatasetName datasetFullId = DatasetName.of(projectId, "us-central1", datasetId);

      // Get multiple Google Cloud Storage URIs to import data from
      GcsSource gcsSource =
          GcsSource.newBuilder().addAllInputUris(Arrays.asList(path.split(","))).build();

      // Import data from the input URI
      InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build();
      System.out.println("Processing import...");

      // Start the import job
      OperationFuture<Empty, OperationMetadata> operation =
          client.importDataAsync(datasetFullId, inputConfig);

      System.out.format("Operation name: %s%n", operation.getName());

      // If you want to wait for the operation to finish, adjust the timeout appropriately. The
      // operation will still run if you choose not to wait for it to complete. You can check the
      // status of your operation using the operation's name.
      Empty response = operation.get(45, TimeUnit.MINUTES);
      System.out.format("Dataset imported. %s%n", response);
    } catch (TimeoutException e) {
      System.out.println("The operation's polling period was not long enough.");
      System.out.println("You can use the Operation's name to get the current status.");
      System.out.println("The import job is still running and will complete as expected.");
      throw e;
    }
  }
}

Node.js

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const datasetId = 'YOUR_DATASET_ID';
// const path = 'gs://BUCKET_ID/path_to_training_data.csv';

// Imports the Google Cloud AutoML library
const {AutoMlClient} = require('@google-cloud/automl').v1;

// Instantiates a client
const client = new AutoMlClient();

async function importDataset() {
  // Construct request
  const request = {
    name: client.datasetPath(projectId, location, datasetId),
    inputConfig: {
      gcsSource: {
        inputUris: path.split(','),
      },
    },
  };

  // Import dataset
  console.log('Processing import...');
  const [operation] = await client.importData(request);

  // Wait for operation to complete.
  const [response] = await operation.promise();
  console.log(`Dataset imported: ${response}`);
}

importDataset();

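Request

Run the import_data function to import the training content. The first value to change is the dataset ID from the previous step, and the second is the URI of all_data.csv. You need to modify the following lines of code:

  • Set project_id to your PROJECT_ID.
  • Set dataset_id for the dataset (from the output of Step 1).
  • Set path to the URI of the CSV file, gs://YOUR_PROJECT_ID-vcm/csv/all_data.csv.

Python

python3 import_dataset.py

Java

mvn compile exec:java -Dexec.mainClass="com.example.automl.ImportDataset"

Node.js

node import_dataset.js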

Response

Processing import...
Dataset imported.

Step 3: Create (train) the model

Now that you have a dataset of labeled training images, you can train a new model.

Copy the code

Python

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# dataset_id = "YOUR_DATASET_ID"
# display_name = "your_models_display_name"

client = automl.AutoMlClient()

# A resource that represents Google Cloud Platform location.
project_location = f"projects/{project_id}/locations/us-central1"
# Leave model unset to use the default base model provided by Google
# train_budget_milli_node_hours: The actual train_cost will be equal or
# less than this value.
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#imageclassificationmodelmetadata
metadata = automl.ImageClassificationModelMetadata(
    train_budget_milli_node_hours=24000
)
model = automl.Model(
    display_name=display_name,
    dataset_id=dataset_id,
    image_classification_model_metadata=metadata,
)

# Create a model with the model metadata in the region.
response = client.create_model(parent=project_location, model=model)

print(f"Training operation name: {response.operation.name}")
print("Training started...")

Java

import com.google.api.gax.longrunning.OperationFuture;
import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.ImageClassificationModelMetadata;
import com.google.cloud.automl.v1.LocationName;
import com.google.cloud.automl.v1.Model;
import com.google.cloud.automl.v1.OperationMetadata;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

class VisionClassificationCreateModel {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String displayName = "YOUR_MODEL_NAME";
    createModel(projectId, datasetId, displayName);
  }

  // Create a model
  static void createModel(String projectId, String datasetId, String displayName)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // A resource that represents Google Cloud Platform location.
      LocationName projectLocation = LocationName.of(projectId, "us-central1");
      // Set model metadata.
      ImageClassificationModelMetadata metadata =
          ImageClassificationModelMetadata.newBuilder().setTrainBudgetMilliNodeHours(24000).build();
      Model model =
          Model.newBuilder()
              .setDisplayName(displayName)
              .setDatasetId(datasetId)
              .setImageClassificationModelMetadata(metadata)
              .build();

      // Create a model with the model metadata in the region.
      OperationFuture<Model, OperationMetadata> future =
          client.createModelAsync(projectLocation, model);
      // OperationFuture.get() will block until the model is created, which may take several hours.
      // You can use OperationFuture.getInitialFuture to get a future representing the initial
      // response to the request, which contains information while the operation is in progress.
      System.out.format("Training operation name: %s\n", future.getInitialFuture().get().getName());
      System.out.println("Training started...");
    }
  }
}

Node.js

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const datasetId = 'YOUR_DATASET_ID';
// const displayName = 'YOUR_DISPLAY_NAME';

// Imports the Google Cloud AutoML library
const {AutoMlClient} = require('@google-cloud/automl').v1;

// Instantiates a client
const client = new AutoMlClient();

async function createModel() {
  // Construct request
  const request = {
    parent: client.locationPath(projectId, location),
    model: {
      displayName: displayName,
      datasetId: datasetId,
      imageClassificationModelMetadata: {
        trainBudgetMilliNodeHours: 24000,
      },
    },
  };

  // Don't wait for the LRO
  const [operation] = await client.createModel(request);
  console.log('Training started...');
  console.log(`Training operation name: ${operation.name}`);
}

createModel();

Request

Call the create_model function to create the model. The dataset ID comes from the earlier steps. You need to modify the following lines of code:

  • Set project_id to your PROJECT_ID.
  • Set dataset_id for the dataset (from the output of Step 1).
  • Set display_name for the new model (flowers_model).

Python

python3 vision_classification_create_model.py

Java

mvn compile exec:java -Dexec.mainClass="com.example.automl.VisionClassificationCreateModel"

Node.js

node vision_classification_create_model.js

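Response

The create_model function starts the training operation and prints the operation name. Training runs asynchronously and can take a while to complete, so you can use the operation ID to check the training status. When training is complete, create_model returns the model ID. As with the dataset ID, we recommend that you set the environment variable MODEL_ID to the returned model ID value.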

Training operation name: projects/216065747626/locations/us-central1/operations/ICN3007727620979824033
Training started...
Model name: projects/216065747626/locations/us-central1/models/ICN7683346839371803263
Model id: ICN7683346839371803263
Model display name: flowers_model
Image classification model metadata:
Training budget: 1
Training cost: 1
Stop reason:
Base model id:
Model create time:
        seconds: 1529649600
        nanos: 966000000
Model deployment state: deployed
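
The samples in this step pass train_budget_milli_node_hours=24000. Reading the unit literally, with "milli" meaning one thousandth, that budget corresponds to 24 node hours. The short sketch below shows the conversion; the helper name is hypothetical, and per the comment in the Python sample the actual training cost is equal to or less than the budget.

```python
def milli_node_hours_to_node_hours(milli_node_hours: int) -> float:
    """Convert a training budget in milli node hours to node hours."""
    if milli_node_hours < 0:
        raise ValueError("budget must be non-negative")
    return milli_node_hours / 1000


budget = 24000  # value passed as train_budget_milli_node_hours in the samples
print(f"{budget} milli node hours = {milli_node_hours_to_node_hours(budget)} node hours")
```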

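Step 4: Evaluate the model

After training, you can evaluate the model's readiness by reviewing its precision, recall, and F1 score.

The evaluation sample (list_model_evaluations) takes the model ID as a parameter.

Copy the code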

Python

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"

client = automl.AutoMlClient()
# Get the full path of the model.
model_full_id = client.model_path(project_id, "us-central1", model_id)

print("List of model evaluations:")
for evaluation in client.list_model_evaluations(parent=model_full_id, filter=""):
    print(f"Model evaluation name: {evaluation.name}")
    print(f"Model annotation spec id: {evaluation.annotation_spec_id}")
    print(f"Create Time: {evaluation.create_time}")
    print(f"Evaluation example count: {evaluation.evaluated_example_count}")
    print(
        "Classification model evaluation metrics: {}".format(
            evaluation.classification_evaluation_metrics
        )
    )

Java


import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.ListModelEvaluationsRequest;
import com.google.cloud.automl.v1.ModelEvaluation;
import com.google.cloud.automl.v1.ModelName;
import java.io.IOException;

class ListModelEvaluations {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    listModelEvaluations(projectId, modelId);
  }

  // List model evaluations
  static void listModelEvaluations(String projectId, String modelId) throws IOException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // Get the full path of the model.
      ModelName modelFullId = ModelName.of(projectId, "us-central1", modelId);
      ListModelEvaluationsRequest modelEvaluationsrequest =
          ListModelEvaluationsRequest.newBuilder().setParent(modelFullId.toString()).build();

      // List all the model evaluations in the model by applying filter.
      System.out.println("List of model evaluations:");
      for (ModelEvaluation modelEvaluation :
          client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) {

        System.out.format("Model Evaluation Name: %s\n", modelEvaluation.getName());
        System.out.format("Model Annotation Spec Id: %s\n", modelEvaluation.getAnnotationSpecId());
        System.out.println("Create Time:");
        System.out.format("\tseconds: %s\n", modelEvaluation.getCreateTime().getSeconds());
        System.out.format("\tnanos: %s\n", modelEvaluation.getCreateTime().getNanos() / 1e9);
        System.out.format(
            "Evaluation Example Count: %d\n", modelEvaluation.getEvaluatedExampleCount());
        System.out.format(
            "Classification Model Evaluation Metrics: %s\n",
            modelEvaluation.getClassificationEvaluationMetrics());
      }
    }
  }
}

Node.js

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const modelId = 'YOUR_MODEL_ID';

// Imports the Google Cloud AutoML library
const {AutoMlClient} = require('@google-cloud/automl').v1;

// Instantiates a client
const client = new AutoMlClient();

async function listModelEvaluations() {
  // Construct request
  const request = {
    parent: client.modelPath(projectId, location, modelId),
    filter: '',
  };

  const [response] = await client.listModelEvaluations(request);

  console.log('List of model evaluations:');
  for (const evaluation of response) {
    console.log(`Model evaluation name: ${evaluation.name}`);
    console.log(`Model annotation spec id: ${evaluation.annotationSpecId}`);
    console.log(`Model display name: ${evaluation.displayName}`);
    console.log('Model create time');
    console.log(`\tseconds ${evaluation.createTime.seconds}`);
    console.log(`\tnanos ${evaluation.createTime.nanos / 1e9}`);
    console.log(
      `Evaluation example count: ${evaluation.evaluatedExampleCount}`
    );
    console.log(
      `Classification model evaluation metrics: ${evaluation.classificationEvaluationMetrics}`
    );
  }
}

listModelEvaluations();

Request

Run the following to display the model's overall evaluation performance. You need to modify the following lines of code:

  • Set project_id to your PROJECT_ID.
  • Set model_id to your model's ID.

Python

python3 list_model_evaluations.py

Java

mvn compile exec:java -Dexec.mainClass="com.example.automl.ListModelEvaluations"

Node.js

node list_model_evaluations.js

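Response

If the precision and recall scores are too low, you can strengthen the training dataset and retrain the model. For more information, see Evaluating models.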

Precision and recall are based on a score threshold of 0.5
Model Precision: 96.3%
Model Recall: 95.7%
Model F1 score: 96.0%
Model Precision@1: 96.33%
Model Recall@1: 95.74%
Model F1 score@1: 96.04%
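
The F1 score in this output is the harmonic mean of precision and recall. As a quick check, the sketch below recomputes it from the Precision@1 and Recall@1 values shown above. The function name f1_score is our own, and because the inputs are already rounded, the result only approximately matches the reported 96.04%.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision@1 and Recall@1 from the sample output, as fractions.
f1 = f1_score(0.9633, 0.9574)
print(f"F1: {f1:.2%}")
```

Because the harmonic mean is pulled toward the smaller of the two inputs, a model with high precision but poor recall (or the reverse) still gets a low F1 score.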

Step 5: Use the model to make a prediction

If your custom model meets your quality standards, you can use it to classify new flower images.

Copy the code

Python

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"
# file_path = "path_to_local_file.jpg"

prediction_client = automl.PredictionServiceClient()

# Get the full path of the model.
model_full_id = automl.AutoMlClient.model_path(project_id, "us-central1", model_id)

# Read the file.
with open(file_path, "rb") as content_file:
    content = content_file.read()

image = automl.Image(image_bytes=content)
payload = automl.ExamplePayload(image=image)

# params is additional domain-specific parameters.
# score_threshold is used to filter the result
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#predictrequest
params = {"score_threshold": "0.8"}

request = automl.PredictRequest(name=model_full_id, payload=payload, params=params)
response = prediction_client.predict(request=request)

print("Prediction results:")
for result in response.payload:
    print(f"Predicted class name: {result.display_name}")
    print(f"Predicted class score: {result.classification.score}")

Java

import com.google.cloud.automl.v1.AnnotationPayload;
import com.google.cloud.automl.v1.ExamplePayload;
import com.google.cloud.automl.v1.Image;
import com.google.cloud.automl.v1.ModelName;
import com.google.cloud.automl.v1.PredictRequest;
import com.google.cloud.automl.v1.PredictResponse;
import com.google.cloud.automl.v1.PredictionServiceClient;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

class VisionClassificationPredict {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    String filePath = "path_to_local_file.jpg";
    predict(projectId, modelId, filePath);
  }

  static void predict(String projectId, String modelId, String filePath) throws IOException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient client = PredictionServiceClient.create()) {
      // Get the full path of the model.
      ModelName name = ModelName.of(projectId, "us-central1", modelId);
      ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath)));
      Image image = Image.newBuilder().setImageBytes(content).build();
      ExamplePayload payload = ExamplePayload.newBuilder().setImage(image).build();
      PredictRequest predictRequest =
          PredictRequest.newBuilder()
              .setName(name.toString())
              .setPayload(payload)
              .putParams(
                  "score_threshold", "0.8") // [0.0-1.0] Only produce results higher than this value
              .build();

      PredictResponse response = client.predict(predictRequest);

      for (AnnotationPayload annotationPayload : response.getPayloadList()) {
        System.out.format("Predicted class name: %s\n", annotationPayload.getDisplayName());
        System.out.format(
            "Predicted class score: %.2f\n", annotationPayload.getClassification().getScore());
      }
    }
  }
}

Node.js

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const modelId = 'YOUR_MODEL_ID';
// const filePath = 'path_to_local_file.jpg';

// Imports the Google Cloud AutoML library
const {PredictionServiceClient} = require('@google-cloud/automl').v1;
const fs = require('fs');

// Instantiates a client
const client = new PredictionServiceClient();

// Read the image file content for prediction.
const content = fs.readFileSync(filePath);

async function predict() {
  // Construct request
  // params is additional domain-specific parameters.
  // score_threshold is used to filter the result
  const request = {
    name: client.modelPath(projectId, location, modelId),
    payload: {
      image: {
        imageBytes: content,
      },
    },
    params: {
      score_threshold: '0.8',
    },
  };

  const [response] = await client.predict(request);

  for (const annotationPayload of response.payload) {
    console.log(`Predicted class name: ${annotationPayload.displayName}`);
    console.log(
      `Predicted class score: ${annotationPayload.classification.score}`
    );
  }
}

predict();

Request

For the predict function, you need to modify the following lines of code:

  • Set project_id to your PROJECT_ID.
  • Set model_id to your model's ID.
  • Set file_path to the downloaded file ("resources/test.png").

Python

python3 vision_classification_predict.py

Java

mvn compile exec:java -Dexec.mainClass="com.example.automl.VisionClassificationPredict"

Node.js

node vision_classification_predict.js

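Response

The function returns the classification scores that are above the confidence threshold set in the request (score_threshold, 0.8 in the samples above). Each score indicates how well the image matches that category.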

Prediction results:
Predicted class name: dandelion
Predicted class score: 0.9702693223953247
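
The threshold filtering itself happens on the service side. To illustrate its effect, the following minimal Python sketch applies the same kind of filter to a local list of scores. The function name filter_by_score and the daisy and sunflowers scores are hypothetical; only the dandelion score comes from the sample output above.

```python
def filter_by_score(predictions, score_threshold):
    """Keep (label, score) pairs whose score is above the threshold, best first."""
    if not 0.0 <= score_threshold <= 1.0:
        raise ValueError("score_threshold must be between 0.0 and 1.0")
    kept = [(label, score) for label, score in predictions if score > score_threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


# Hypothetical raw scores for one image.
raw = [("dandelion", 0.9702693223953247), ("daisy", 0.02), ("sunflowers", 0.01)]
for label, score in filter_by_score(raw, 0.8):
    print(f"Predicted class name: {label}")
    print(f"Predicted class score: {score}")
```

Lowering the threshold returns more candidate labels at the cost of including less confident ones.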

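Step 6: Delete the model

When you are done with this sample model, you can delete it permanently. After that, you can no longer use the model for predictions.

Copy the code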

Python

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"

client = automl.AutoMlClient()
# Get the full path of the model.
model_full_id = client.model_path(project_id, "us-central1", model_id)
response = client.delete_model(name=model_full_id)

print(f"Model deleted. {response.result()}")

Java

import com.google.cloud.automl.v1.AutoMlClient;
import com.google.cloud.automl.v1.ModelName;
import com.google.protobuf.Empty;
import java.io.IOException;
import java.util.concurrent.ExecutionException;

class DeleteModel {

  public static void main(String[] args)
      throws IOException, ExecutionException, InterruptedException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    deleteModel(projectId, modelId);
  }

  // Delete a model
  static void deleteModel(String projectId, String modelId)
      throws IOException, ExecutionException, InterruptedException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (AutoMlClient client = AutoMlClient.create()) {
      // Get the full path of the model.
      ModelName modelFullId = ModelName.of(projectId, "us-central1", modelId);

      // Delete a model and block until the deletion completes.
      Empty response = client.deleteModelAsync(modelFullId).get();

      System.out.println(String.format("Model deleted. %s", response));
    }
  }
}

Node.js

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const modelId = 'YOUR_MODEL_ID';

// Imports the Google Cloud AutoML library
const {AutoMlClient} = require('@google-cloud/automl').v1;

// Instantiates a client
const client = new AutoMlClient();

async function deleteModel() {
  // Construct request
  const request = {
    name: client.modelPath(projectId, location, modelId),
  };

  // deleteModel starts a long-running operation; wait for it to finish.
  const [operation] = await client.deleteModel(request);
  await operation.promise();
  console.log('Model deleted.');
}

deleteModel();

Request

Make a request with the delete_model operation to delete the model that you created. You need to modify the following lines of code:

  • Set project_id to your PROJECT_ID.
  • Set model_id to your model's ID.

Python

python3 delete_model.py

Java

mvn compile exec:java -Dexec.mainClass="com.example.automl.DeleteModel"

Node.js

node delete_model.js

Response

Model deleted.