AutoML Vision API Tutorial

This tutorial demonstrates how to create a new model from your own set of training images, evaluate the model's results, and predict the classification of a test image using AutoML Vision.

The tutorial uses a dataset with images of five kinds of flowers: sunflowers, tulips, daisies, roses, and dandelions. It covers training a custom model, evaluating model performance, and classifying new images using the custom model.

Prerequisites

Configure your project environment

  1. Sign in to your Google Account.

    If you don't already have one, sign up for a new account.

  2. In the GCP Console, go to the Manage resources page and select or create a new project.

    Go to the Manage resources page

  3. Make sure that billing is enabled for your project.

    Learn how to enable billing

  4. Enable the AutoML Vision APIs.

    Enable the APIs

  5. Install the gcloud command line tool.
  6. Follow the instructions to create a service account and download a key file.
  7. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the path to the service account key file that you downloaded when you created the service account. For example:
         export GOOGLE_APPLICATION_CREDENTIALS=key-file
  8. Add your new service account to the AutoML Editor IAM role with the following commands. Replace project-id with the name of your GCP project and replace service-account-name with the name of your new service account, for example service-account1@myproject.iam.gserviceaccount.com.
         gcloud auth login
         gcloud projects add-iam-policy-binding project-id \
           --member="user:your-userid@your-domain" \
           --role="roles/automl.admin"
         gcloud projects add-iam-policy-binding project-id \
           --member="serviceAccount:service-account-name" \
           --role="roles/automl.editor"
  9. Allow the AutoML Vision service accounts to access your Google Cloud project resources:
    gcloud projects add-iam-policy-binding project-id \
      --member="serviceAccount:custom-vision@appspot.gserviceaccount.com" \
      --role="roles/storage.admin"
  10. Install the client library.
  11. Set the PROJECT_ID and REGION_NAME environment variables.

    Replace project-id with the Project ID of your Google Cloud Platform project. AutoML Vision currently requires the location us-central1.

    export PROJECT_ID="project-id"
    export REGION_NAME="us-central1"
    

  12. Create a Google Cloud Storage bucket to store the images that you will use to train your custom model.

    The bucket name must be in the format: $PROJECT_ID-vcm. The following command creates a storage bucket in the us-central1 region named $PROJECT_ID-vcm.
    gsutil mb -p $PROJECT_ID -c regional -l $REGION_NAME gs://$PROJECT_ID-vcm/
  13. Copy the publicly available dataset of flower images from gs://cloud-ml-data/img/flower_photos/ into your Google Cloud Storage bucket.

    In your Cloud Shell session, enter:
      gsutil -m cp -R gs://cloud-ml-data/img/flower_photos/  gs://$PROJECT_ID-vcm/img/
      
    The file copying takes about 20 minutes to complete.
  14. The sample dataset contains a .csv file with the location and labels for each image. (See Preparing your training data for details about the required format.) Update the .csv file to point to the files in your own bucket:
       gsutil cat gs://$PROJECT_ID-vcm/img/flower_photos/all_data.csv | sed "s:cloud-ml-data:$PROJECT_ID-vcm:" > all_data.csv
       
    Then copy the updated .csv file into your bucket:
       gsutil cp all_data.csv gs://$PROJECT_ID-vcm/csv/
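The sed command above only swaps the bucket name; the object path and label in each row stay intact. A minimal Python sketch of the same rewrite, using two made-up rows in the image-list CSV format (the image file names and the "myproject" project ID are illustrative, not real values):

```python
# Illustrative only: mimics the sed rewrite above on two made-up rows.
# "myproject" stands in for your real project ID.
project_id = "myproject"

rows = [
    "gs://cloud-ml-data/img/flower_photos/daisy/100080576.jpg,daisy",
    "gs://cloud-ml-data/img/flower_photos/roses/10090824183.jpg,roses",
]

# Swap the public bucket name for your own bucket; the rest of each
# row (object path and label) is left untouched.
updated = [row.replace("cloud-ml-data", project_id + "-vcm", 1) for row in rows]

for row in updated:
    print(row)
```

Each printed row points at your bucket (gs://myproject-vcm/...) while keeping the original path and label.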
       

Copy the source code files into your Google Cloud Platform project folder

Python

The tutorial consists of these Python files:

  • automl_vision_dataset.py – Includes functionality related to datasets
  • automl_vision_model.py – Includes functionality related to models
  • automl_vision_predict.py – Includes functionality related to predictions

Java

The tutorial consists of these Java files (the class names appear in the mvn commands later in this tutorial):

  • DatasetApi.java – Includes functionality related to datasets
  • ModelApi.java – Includes functionality related to models
  • PredictionApi.java – Includes functionality related to predictions

Node.js

The tutorial consists of these Node.js programs:

  • automlVisionDataset.js – Includes functionality related to datasets
  • automlVisionModel.js – Includes functionality related to models
  • automlVisionPredict.js – Includes functionality related to predictions

Running the application

Step 1: Create the Flowers dataset

The first step in creating a custom model is to create an empty dataset that will eventually hold the training data for the model. When you create a dataset, you specify the type of classification you want your custom model to perform:

  • MULTICLASS assigns a single label to each classified image
  • MULTILABEL allows an image to be assigned multiple labels

This tutorial creates a dataset named flowers and uses MULTICLASS.
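The distinction between the two types can be pictured with a tiny sketch (the image names and labels below are made up for illustration):

```python
# Illustration only; image names and labels are hypothetical.
# MULTICLASS: each classified image receives exactly one label.
multiclass_labels = {"flower1.jpg": "daisy"}

# MULTILABEL: a single image may receive several labels at once.
multilabel_labels = {"flower1.jpg": ["daisy", "sunflowers"]}
```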

Request

Run the create_dataset function to create an empty dataset. The first parameter names the dataset (flowers); the second specifies whether the dataset is MULTILABEL (True) or MULTICLASS (False).

Python

python automl_vision_dataset.py create_dataset "flowers" "False"

Java

mvn compile exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.DatasetApi" -Dexec.args="create_dataset flowers false"

Node.js

node automlVisionDataset.js create-dataset "flowers" "False"

Code

Python

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# dataset_name = 'DATASET_NAME_HERE'
# multilabel = True for multilabel or False for multiclass

from google.cloud import automl_v1beta1 as automl

client = automl.AutoMlClient()

# A resource that represents Google Cloud Platform location.
project_location = client.location_path(project_id, compute_region)

# Classification type is assigned based on multilabel value.
classification_type = "MULTICLASS"
if multilabel:
    classification_type = "MULTILABEL"

# Specify the image classification type for the dataset.
dataset_metadata = {"classification_type": classification_type}
# Set dataset name and metadata of the dataset.
my_dataset = {
    "display_name": dataset_name,
    "image_classification_dataset_metadata": dataset_metadata,
}

# Create a dataset with the dataset metadata in the region.
dataset = client.create_dataset(project_location, my_dataset)

# Display the dataset information.
print("Dataset name: {}".format(dataset.name))
print("Dataset id: {}".format(dataset.name.split("/")[-1]))
print("Dataset display name: {}".format(dataset.display_name))
print("Image classification dataset metadata:")
print("\t{}".format(dataset.image_classification_dataset_metadata))
print("Dataset example count: {}".format(dataset.example_count))
print("Dataset create time:")
print("\tseconds: {}".format(dataset.create_time.seconds))
print("\tnanos: {}".format(dataset.create_time.nanos))

Java

/**
 * Demonstrates using the AutoML client to create a dataset
 *
 * @param projectId the Google Cloud Project ID.
 * @param computeRegion the Region name. (e.g., "us-central1")
 * @param datasetName the name of the dataset to be created.
 * @param multiLabel the type of classification problem: false (default) for
 *     MULTICLASS, true for MULTILABEL.
 * @throws IOException on Input/Output errors.
 */
public static void createDataset(
    String projectId, String computeRegion, String datasetName, Boolean multiLabel)
    throws IOException {
  // Instantiates a client
  AutoMlClient client = AutoMlClient.create();

  // A resource that represents Google Cloud Platform location.
  LocationName projectLocation = LocationName.of(projectId, computeRegion);

  // Classification type assigned based on multiLabel value.
  ClassificationType classificationType =
      multiLabel ? ClassificationType.MULTILABEL : ClassificationType.MULTICLASS;

  // Specify the image classification type for the dataset.
  ImageClassificationDatasetMetadata imageClassificationDatasetMetadata =
      ImageClassificationDatasetMetadata.newBuilder()
          .setClassificationType(classificationType)
          .build();

  // Set dataset with dataset name and set the dataset metadata.
  Dataset myDataset =
      Dataset.newBuilder()
          .setDisplayName(datasetName)
          .setImageClassificationDatasetMetadata(imageClassificationDatasetMetadata)
          .build();

  // Create dataset with the dataset metadata in the region.
  Dataset dataset = client.createDataset(projectLocation, myDataset);

  // Display the dataset information
  System.out.println(String.format("Dataset name: %s", dataset.getName()));
  System.out.println(
      String.format(
          "Dataset id: %s",
          dataset.getName().split("/")[dataset.getName().split("/").length - 1]));
  System.out.println(String.format("Dataset display name: %s", dataset.getDisplayName()));
  System.out.println("Image classification dataset specification:");
  System.out.print(String.format("\t%s", dataset.getImageClassificationDatasetMetadata()));
  System.out.println(String.format("Dataset example count: %d", dataset.getExampleCount()));
  System.out.println("Dataset create time:");
  System.out.println(String.format("\tseconds: %s", dataset.getCreateTime().getSeconds()));
  System.out.println(String.format("\tnanos: %s", dataset.getCreateTime().getNanos()));
}

Node.js

  const automl = require(`@google-cloud/automl`).v1beta1;

  const client = new automl.AutoMlClient();

  /**
   * TODO(developer): Uncomment the following line before running the sample.
   */
  // const projectId = `The GCLOUD_PROJECT string, e.g. "my-gcloud-project"`;
  // const computeRegion = `region-name, e.g. "us-central1"`;
  // const datasetName = `name of the dataset to create, e.g. "myDataset"`;
  // const multiLabel = `type of classification problem, true for multilabel and false for multiclass e.g. "false"`;

  // A resource that represents Google Cloud Platform location.
  const projectLocation = client.locationPath(projectId, computeRegion);

  // Classification type is assigned based on multilabel value.
  let classificationType = `MULTICLASS`;
  if (multiLabel) {
    classificationType = `MULTILABEL`;
  }

  // Specify the image classification type for the dataset.
  const datasetMetadata = {
    classificationType: classificationType,
  };

  // Set dataset name and metadata.
  const myDataset = {
    displayName: datasetName,
    imageClassificationDatasetMetadata: datasetMetadata,
  };

  // Create a dataset with the dataset metadata in the region.
  client
    .createDataset({parent: projectLocation, dataset: myDataset})
    .then(responses => {
      const dataset = responses[0];

      // Display the dataset information.
      console.log(`Dataset name: ${dataset.name}`);
      console.log(`Dataset id: ${dataset.name.split(`/`).pop()}`);
      console.log(`Dataset display name: ${dataset.displayName}`);
      console.log(`Dataset example count: ${dataset.exampleCount}`);
      console.log(`Image Classification type:`);
      console.log(
        `\t ${dataset.imageClassificationDatasetMetadata.classificationType}`
      );
      console.log(`Dataset create time:`);
      console.log(`\tseconds: ${dataset.createTime.seconds}`);
      console.log(`\tnanos: ${dataset.createTime.nanos}`);
    })
    .catch(err => {
      console.error(err);
    });

Response

The response includes the details of the newly created dataset, including the Dataset ID that you'll use to reference the dataset in future requests. We recommend that you set an environment variable DATASET_ID to the returned Dataset ID value.

Dataset name: projects/216065747626/locations/us-central1/datasets/ICN7372141011130533778
Dataset id: ICN7372141011130533778
Dataset display name: flowers
Image classification dataset specification:
       classification_type: MULTICLASS
Dataset example count: 0
Dataset create time:
       seconds: 1530251987
       nanos: 216586000
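Setting that variable from the sample response above would look like this (the ID shown is from the example output; substitute the Dataset id your own run prints):

```shell
# Hypothetical ID copied from the example response above; use the
# Dataset id printed by your own create_dataset call.
export DATASET_ID="ICN7372141011130533778"
```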

Step 2: Import images into the dataset

The next step is to populate the dataset with training images labeled using the target labels.

The import_data function takes as input a .csv file that lists the location of each training image and its label. (See Preparing your training data for details about the required format.) For this tutorial, you will use the labeled images that you copied into your Google Cloud Storage bucket, which are listed in gs://$PROJECT_ID-vcm/csv/all_data.csv.

Request

Run the import_data function to import the training content. The first parameter is the Dataset ID from the previous step and the second parameter is the URI of all_data.csv.

Python

python automl_vision_dataset.py import_data $DATASET_ID "gs://$PROJECT_ID-vcm/csv/all_data.csv"

Java

mvn compile exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.DatasetApi" -Dexec.args="import_data $DATASET_ID gs://$PROJECT_ID-vcm/csv/all_data.csv"

Node.js

node automlVisionDataset.js import-data $DATASET_ID "gs://$PROJECT_ID-vcm/csv/all_data.csv"

Code

Python

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# dataset_id = 'DATASET_ID_HERE'
# path = 'gs://path/to/file.csv'

from google.cloud import automl_v1beta1 as automl

client = automl.AutoMlClient()

# Get the full path of the dataset.
dataset_full_id = client.dataset_path(
    project_id, compute_region, dataset_id
)

# Get the multiple Google Cloud Storage URIs.
input_uris = path.split(",")
input_config = {"gcs_source": {"input_uris": input_uris}}

# Import data from the input URI.
response = client.import_data(dataset_full_id, input_config)

print("Processing import...")
# synchronous check of operation status.
print("Data imported. {}".format(response.result()))

Java

/**
 * Demonstrates using the AutoML client to import labeled images.
 *
 * @param projectId the Id of the project.
 * @param computeRegion the Region name.
 * @param datasetId the Id of the dataset to which the training data will be imported.
 * @param path the Google Cloud Storage URIs. Target files must be in AutoML vision CSV format.
 * @throws Exception on AutoML Client errors
 */
public static void importData(
    String projectId, String computeRegion, String datasetId, String path) throws Exception {
  // Instantiates a client
  AutoMlClient client = AutoMlClient.create();

  // Get the complete path of the dataset.
  DatasetName datasetFullId = DatasetName.of(projectId, computeRegion, datasetId);

  GcsSource.Builder gcsSource = GcsSource.newBuilder();

  // Get multiple training data files to be imported
  String[] inputUris = path.split(",");
  for (String inputUri : inputUris) {
    gcsSource.addInputUris(inputUri);
  }

  // Import data from the input URI
  InputConfig inputConfig = InputConfig.newBuilder().setGcsSource(gcsSource).build();
  System.out.println("Processing import...");
  Empty response = client.importDataAsync(datasetFullId.toString(), inputConfig).get();
  System.out.println(String.format("Dataset imported. %s", response));
}

Node.js

  const automl = require(`@google-cloud/automl`).v1beta1;

  const client = new automl.AutoMlClient();

  /**
   * TODO(developer): Uncomment the following line before running the sample.
   */
  // const projectId = `The GCLOUD_PROJECT string, e.g. "my-gcloud-project"`;
  // const computeRegion = `region-name, e.g. "us-central1"`;
  // const datasetId = `Id of the dataset`;
  // const path = `string or array of .csv paths in AutoML Vision CSV format, e.g. "gs://myproject/traindata.csv"`;

  // Get the full path of the dataset.
  const datasetFullId = client.datasetPath(projectId, computeRegion, datasetId);

  // Get one or more Google Cloud Storage URI(s).
  const inputUris = path.split(`,`);
  const inputConfig = {
    gcsSource: {
      inputUris: inputUris,
    },
  };

  // Import the dataset from the input URI.
  client
    .importData({name: datasetFullId, inputConfig: inputConfig})
    .then(responses => {
      const operation = responses[0];
      console.log(`Processing import...`);
      return operation.promise();
    })
    .then(responses => {
      // The final result of the operation.
      if (responses[2].done) {
        console.log(`Data imported.`);
      }
    })
    .catch(err => {
      console.error(err);
    });

Response

Processing import...
Dataset imported.

Step 3: Create (train) the model

Now that you have a dataset of labeled training images, you can train a new model.

Request

The first parameter for the create_model function is the Dataset ID from the previous steps, the second parameter is a name for the new model, and the third is the training budget. The training budget is the number of hours of training to use for the model. AutoML Vision guarantees that the actual training time will be less than or equal to the training budget.

A training hour represents internal compute usage, and therefore does not exactly match an actual hour on the clock.

Python

python automl_vision_model.py create_model $DATASET_ID "flowers_model" "1"

Java

mvn compile exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="create_model $DATASET_ID flowers_model 1"

Node.js

node automlVisionModel.js create-model $DATASET_ID "flowers_model" "1"

Code

Python

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# dataset_id = 'DATASET_ID_HERE'
# model_name = 'MODEL_NAME_HERE'
# train_budget = integer amount for maximum cost of model

from google.cloud import automl_v1beta1 as automl

client = automl.AutoMlClient()

# A resource that represents Google Cloud Platform location.
project_location = client.location_path(project_id, compute_region)

# Set model name and model metadata for the image dataset.
my_model = {
    "display_name": model_name,
    "dataset_id": dataset_id,
    "image_classification_model_metadata": {"train_budget": train_budget}
    if train_budget
    else {},
}

# Create a model with the model metadata in the region.
response = client.create_model(project_location, my_model)

print("Training operation name: {}".format(response.operation.name))
print("Training started...")

Java

/**
 * Demonstrates using the AutoML client to create a model.
 *
 * @param projectId the Id of the project.
 * @param computeRegion the Region name.
 * @param dataSetId the Id of the dataset to which model is created.
 * @param modelName the Name of the model.
 * @param trainBudget the Budget for training the model.
 * @throws Exception on AutoML Client errors
 */
public static void createModel(
    String projectId,
    String computeRegion,
    String dataSetId,
    String modelName,
    String trainBudget)
    throws Exception {
  // Instantiates a client
  AutoMlClient client = AutoMlClient.create();

  // A resource that represents Google Cloud Platform location.
  LocationName projectLocation = LocationName.of(projectId, computeRegion);

  // Set model metadata.
  ImageClassificationModelMetadata imageClassificationModelMetadata =
      Long.valueOf(trainBudget) == 0
          ? ImageClassificationModelMetadata.newBuilder().build()
          : ImageClassificationModelMetadata.newBuilder()
              .setTrainBudget(Long.valueOf(trainBudget))
              .build();

  // Set model name and model metadata for the image dataset.
  Model myModel =
      Model.newBuilder()
          .setDisplayName(modelName)
          .setDatasetId(dataSetId)
          .setImageClassificationModelMetadata(imageClassificationModelMetadata)
          .build();

  // Create a model with the model metadata in the region.
  OperationFuture<Model, OperationMetadata> response =
      client.createModelAsync(projectLocation, myModel);

  System.out.println(
      String.format("Training operation name: %s", response.getInitialFuture().get().getName()));
  System.out.println("Training started...");
}

Node.js

  const automl = require(`@google-cloud/automl`).v1beta1;

  const client = new automl.AutoMlClient();

  /**
   * TODO(developer): Uncomment the following line before running the sample.
   */
  // const projectId = `The GCLOUD_PROJECT string, e.g. "my-gcloud-project"`;
  // const computeRegion = `region-name, e.g. "us-central1"`;
  // const datasetId = `Id of the dataset`;
  // const modelName = `Name of the model, e.g. "myModel"`;
  // const trainBudget = `Budget for training model, e.g. 50`;

  // A resource that represents Google Cloud Platform location.
  const projectLocation = client.locationPath(projectId, computeRegion);

  // Check train budget condition.
  if (trainBudget === 0) {
    trainBudget = {};
  } else {
    trainBudget = {trainBudget: trainBudget};
  }

  // Set model name and model metadata for the dataset.
  const myModel = {
    displayName: modelName,
    datasetId: datasetId,
    imageClassificationModelMetadata: trainBudget,
  };

  // Create a model with the model metadata in the region.
  client
    .createModel({parent: projectLocation, model: myModel})
    .then(responses => {
      const operation = responses[0];
      const initialApiResponse = responses[1];

      console.log(`Training operation name: `, initialApiResponse.name);
      console.log(`Training started...`);
      return operation.promise();
    })
    .then(responses => {
      // The final result of the operation.
      const model = responses[0];

      // Retrieve deployment state.
      let deploymentState = ``;
      if (model.deploymentState === 1) {
        deploymentState = `deployed`;
      } else if (model.deploymentState === 2) {
        deploymentState = `undeployed`;
      }

      // Display the model information.
      console.log(`Model name: ${model.name}`);
      console.log(`Model id: ${model.name.split(`/`).pop()}`);
      console.log(`Model display name: ${model.displayName}`);
      console.log(`Model create time:`);
      console.log(`\tseconds: ${model.createTime.seconds}`);
      console.log(`\tnanos: ${model.createTime.nanos}`);
      console.log(`Model deployment state: ${deploymentState}`);
    })
    .catch(err => {
      console.error(err);
    });

Response

The create_model function kicks off a training operation and prints the operation name. Training happens asynchronously and can take a while to complete, so you can use the operation ID to check training status. When training is complete, create_model returns the Model ID. As with the Dataset ID, you might want to set an environment variable MODEL_ID to the returned Model ID value.

Training operation name: projects/216065747626/locations/us-central1/operations/ICN3007727620979824033
Training started...
Model name: projects/216065747626/locations/us-central1/models/ICN7683346839371803263
Model id: ICN7683346839371803263
Model display name: flowers_model
Image classification model metadata:
Training budget: 1
Training cost: 1
Stop reason:
Base model id:
Model create time:
        seconds: 1529649600
        nanos: 966000000
Model deployment state: deployed

Step 4: Evaluate the model

After training, you can evaluate your model's readiness by reviewing its precision, recall, and F1 score.
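F1 is the harmonic mean of precision and recall. As a quick check, here is a minimal sketch that reproduces the overall F1 score from the example precision and recall values shown in the sample output later in this step:

```python
# Example values from the sample evaluation output in this step:
# 96.3% precision and 95.7% recall at the 0.5 score threshold.
precision = 0.963
recall = 0.957

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)
print(round(f1 * 100, 1))  # 96.0, matching the reported F1 score
```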

The display_evaluation function takes the Model ID as a parameter.

Request

Make a request to display the overall evaluation performance of the model by executing the following request with operation type display_evaluation. Pass the Model ID and filter (optional) as arguments.

Python

python automl_vision_model.py display_evaluation $MODEL_ID

Java

mvn compile exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="display_evaluation $MODEL_ID"

Node.js

node automlVisionModel.js display-evaluation $MODEL_ID

Code

Python

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_id = 'MODEL_ID_HERE'
# filter_ = 'filter expression here'

from google.cloud import automl_v1beta1 as automl

client = automl.AutoMlClient()

# Get the full path of the model.
model_full_id = client.model_path(project_id, compute_region, model_id)

# List all the model evaluations in the model by applying filter.
response = client.list_model_evaluations(model_full_id, filter_)

# Iterate through the results.
for element in response:
    # There is evaluation for each class in a model and for overall model.
    # Get only the evaluation of overall model.
    if not element.annotation_spec_id:
        model_evaluation_id = element.name.split("/")[-1]

# Resource name for the model evaluation.
model_evaluation_full_id = client.model_evaluation_path(
    project_id, compute_region, model_id, model_evaluation_id
)

# Get a model evaluation.
model_evaluation = client.get_model_evaluation(model_evaluation_full_id)

class_metrics = model_evaluation.classification_evaluation_metrics
confidence_metrics_entries = class_metrics.confidence_metrics_entry

# Showing model score based on threshold of 0.5
for confidence_metrics_entry in confidence_metrics_entries:
    if confidence_metrics_entry.confidence_threshold == 0.5:
        print("Precision and recall are based on a score threshold of 0.5")
        print(
            "Model Precision: {}%".format(
                round(confidence_metrics_entry.precision * 100, 2)
            )
        )
        print(
            "Model Recall: {}%".format(
                round(confidence_metrics_entry.recall * 100, 2)
            )
        )
        print(
            "Model F1 score: {}%".format(
                round(confidence_metrics_entry.f1_score * 100, 2)
            )
        )
        print(
            "Model Precision@1: {}%".format(
                round(confidence_metrics_entry.precision_at1 * 100, 2)
            )
        )
        print(
            "Model Recall@1: {}%".format(
                round(confidence_metrics_entry.recall_at1 * 100, 2)
            )
        )
        print(
            "Model F1 score@1: {}%".format(
                round(confidence_metrics_entry.f1_score_at1 * 100, 2)
            )
        )

Java

/**
 * Demonstrates using the AutoML client to display model evaluation.
 *
 * @param projectId the Id of the project.
 * @param computeRegion the Region name.
 * @param modelId the Id of the model.
 * @param filter the filter expression.
 * @throws IOException on Input/Output errors.
 */
public static void displayEvaluation(
    String projectId, String computeRegion, String modelId, String filter) throws IOException {
  AutoMlClient client = AutoMlClient.create();

  // Get the full path of the model.
  ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId);

  // List all the model evaluations in the model by applying filter.
  ListModelEvaluationsRequest modelEvaluationsrequest =
      ListModelEvaluationsRequest.newBuilder()
          .setParent(modelFullId.toString())
          .setFilter(filter)
          .build();

  // Iterate through the results.
  String modelEvaluationId = "";
  for (ModelEvaluation element :
      client.listModelEvaluations(modelEvaluationsrequest).iterateAll()) {
    // There is an evaluation per class plus one for the overall model;
    // keep only the overall evaluation, whose annotation spec ID is empty.
    if (element.getAnnotationSpecId().isEmpty()) {
      modelEvaluationId = element.getName().split("/")[element.getName().split("/").length - 1];
    }
  }

  // Resource name for the model evaluation.
  ModelEvaluationName modelEvaluationFullId =
      ModelEvaluationName.of(projectId, computeRegion, modelId, modelEvaluationId);

  // Get a model evaluation.
  ModelEvaluation modelEvaluation = client.getModelEvaluation(modelEvaluationFullId);

  ClassificationEvaluationMetrics classMetrics =
      modelEvaluation.getClassificationEvaluationMetrics();
  List<ConfidenceMetricsEntry> confidenceMetricsEntries =
      classMetrics.getConfidenceMetricsEntryList();

  // Showing model score based on threshold of 0.5
  for (ConfidenceMetricsEntry confidenceMetricsEntry : confidenceMetricsEntries) {
    if (confidenceMetricsEntry.getConfidenceThreshold() == 0.5) {
      System.out.println("Precision and recall are based on a score threshold of 0.5");
      System.out.println(
          String.format("Model Precision: %.2f ", confidenceMetricsEntry.getPrecision() * 100)
              + '%');
      System.out.println(
          String.format("Model Recall: %.2f ", confidenceMetricsEntry.getRecall() * 100) + '%');
      System.out.println(
          String.format("Model F1 score: %.2f ", confidenceMetricsEntry.getF1Score() * 100)
              + '%');
      System.out.println(
          String.format(
                  "Model Precision@1: %.2f ", confidenceMetricsEntry.getPrecisionAt1() * 100)
              + '%');
      System.out.println(
          String.format("Model Recall@1: %.2f ", confidenceMetricsEntry.getRecallAt1() * 100)
              + '%');
      System.out.println(
          String.format("Model F1 score@1: %.2f ", confidenceMetricsEntry.getF1ScoreAt1() * 100)
              + '%');
    }
  }
}

Node.js

  const automl = require(`@google-cloud/automl`).v1beta1;
  const math = require(`mathjs`);

  const client = new automl.AutoMlClient();

  /**
   * TODO(developer): Uncomment the following line before running the sample.
   */
  // const projectId = `The GCLOUD_PROJECT string, e.g. "my-gcloud-project"`;
  // const computeRegion = `region-name, e.g. "us-central1"`;
  // const modelId = `id of the model, e.g. "ICN12345"`;
  // const filter = `filter expressions, must specify field, e.g. "imageClassificationModelMetadata:*"`;

  // Get the full path of the model.
  const modelFullId = client.modelPath(projectId, computeRegion, modelId);

  // List all the model evaluations in the model by applying filter.
  client
    .listModelEvaluations({parent: modelFullId, filter: filter})
    .then(respond => {
      const response = respond[0];
      response.forEach(element => {
        // There is evaluation for each class in a model and for overall model.
        // Get only the evaluation of overall model.
        if (!element.annotationSpecId) {
          const modelEvaluationId = element.name.split(`/`).pop();

          // Resource name for the model evaluation.
          const modelEvaluationFullId = client.modelEvaluationPath(
            projectId,
            computeRegion,
            modelId,
            modelEvaluationId
          );

          // Get a model evaluation.
          client
            .getModelEvaluation({name: modelEvaluationFullId})
            .then(responses => {
              const modelEvaluation = responses[0];

              const classMetrics =
                modelEvaluation.classificationEvaluationMetrics;

              const confidenceMetricsEntries =
                classMetrics.confidenceMetricsEntry;

              // Showing model score based on threshold of 0.5
              confidenceMetricsEntries.forEach(confidenceMetricsEntry => {
                if (confidenceMetricsEntry.confidenceThreshold === 0.5) {
                  console.log(
                    `Precision and recall are based on a score threshold of 0.5`
                  );
                  // math here is the mathjs package (required earlier in the
                  // full sample); Math.round alone cannot round to 2 decimals.
                  console.log(
                    `Model Precision: ${math.round(
                      confidenceMetricsEntry.precision * 100,
                      2
                    )}%`
                  );
                  console.log(
                    `Model Recall: ${math.round(
                      confidenceMetricsEntry.recall * 100,
                      2
                    )}%`
                  );
                  console.log(
                    `Model F1 score: ${math.round(
                      confidenceMetricsEntry.f1Score * 100,
                      2
                    )}%`
                  );
                  console.log(
                    `Model Precision@1: ${math.round(
                      confidenceMetricsEntry.precisionAt1 * 100,
                      2
                    )}%`
                  );
                  console.log(
                    `Model Recall@1: ${math.round(
                      confidenceMetricsEntry.recallAt1 * 100,
                      2
                    )}%`
                  );
                  console.log(
                    `Model F1 score@1: ${math.round(
                      confidenceMetricsEntry.f1ScoreAt1 * 100,
                      2
                    )}%`
                  );
                }
              });
            })
            .catch(err => {
              console.error(err);
            });
        }
      });
    })
    .catch(err => {
      console.error(err);
    });

Response

If the precision and recall scores are too low, you can strengthen the training dataset and re-train your model. For more information, see Evaluating models.

Precision and recall are based on a score threshold of 0.5
Model Precision: 96.3%
Model Recall: 95.7%
Model F1 score: 96.0%
Model Precision@1: 96.33%
Model Recall@1: 95.74%
Model F1 score@1: 96.04%
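As a quick sanity check, the F1 score reported above is the harmonic mean of precision and recall. You can verify this with a few lines of Python, using the values from the sample output:

```python
# F1 is the harmonic mean of precision and recall.
precision = 96.3  # Model Precision from the sample output, in percent
recall = 95.7     # Model Recall from the sample output, in percent

f1 = 2 * precision * recall / (precision + recall)
print(round(f1, 1))  # 96.0, matching the reported Model F1 score
```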

Step 5: Use a model to make a prediction

When your custom model meets your quality standards, you can use it to classify new flower images.

Request

The predict function takes as parameters the Model ID, the URI of the image to classify, and a confidence score threshold. This tutorial uses 0.7 as the confidence score threshold; it only returns results that have a score of at least 0.7.

  • python automl_vision_predict.py predict $MODEL_ID "resources/test.png" "0.7" {Python}

  • mvn compile exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.PredictionApi" -Dexec.args="predict $MODEL_ID resources/test.png 0.7" {Java}

  • node automlVisionPredict.js predict $MODEL_ID "resources/test.png" "0.7" {Node.js}
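The score threshold acts as a server-side filter: only classifications whose confidence meets or exceeds the threshold are returned. Conceptually, the filtering works like this local sketch (the class scores below are made up for illustration, not actual service output):

```python
# Hypothetical raw scores for the five flower classes.
scores = {
    "dandelion": 0.97,
    "sunflowers": 0.02,
    "tulips": 0.005,
    "daisy": 0.003,
    "roses": 0.002,
}

score_threshold = 0.7

# Keep only classes at or above the threshold, as the service does.
results = {name: s for name, s in scores.items() if s >= score_threshold}
print(results)  # {'dandelion': 0.97}
```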

Code

Python

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_id = 'MODEL_ID_HERE'
# file_path = '/local/path/to/file'
# score_threshold = 'value from 0.0 to 1.0'

from google.cloud import automl_v1beta1 as automl

automl_client = automl.AutoMlClient()

# Get the full path of the model.
model_full_id = automl_client.model_path(
    project_id, compute_region, model_id
)

# Create client for prediction service.
prediction_client = automl.PredictionServiceClient()

# Read the image and assign to payload.
with open(file_path, "rb") as image_file:
    content = image_file.read()
payload = {"image": {"image_bytes": content}}

# params holds additional domain-specific parameters.
# score_threshold is used to filter the results.
params = {}
if score_threshold:
    params = {"score_threshold": score_threshold}

response = prediction_client.predict(model_full_id, payload, params)
print("Prediction results:")
for result in response.payload:
    print("Predicted class name: {}".format(result.display_name))
    print("Predicted class score: {}".format(result.classification.score))
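For reference, `model_path` only assembles the model's fully qualified resource name from its parts. A hypothetical local equivalent (illustrative only; use the client helper in real code):

```python
def model_path(project_id, compute_region, model_id):
    # Format used by AutoML model resource names.
    return "projects/{}/locations/{}/models/{}".format(
        project_id, compute_region, model_id
    )

print(model_path("my-project", "us-central1", "ICN12345"))
# projects/my-project/locations/us-central1/models/ICN12345
```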

Java

/**
 * Demonstrates using the AutoML client to predict an image.
 *
 * @param projectId the Id of the project.
 * @param computeRegion the Region name.
 * @param modelId the Id of the model which will be used for image classification.
 * @param filePath the local file path of the image to be classified.
 * @param scoreThreshold the Confidence score. Only classifications with confidence score above
 *     scoreThreshold are displayed.
 * @throws IOException on Input/Output errors.
 */
public static void predict(
    String projectId,
    String computeRegion,
    String modelId,
    String filePath,
    String scoreThreshold)
    throws IOException {

  // Instantiate client for prediction service.
  PredictionServiceClient predictionClient = PredictionServiceClient.create();

  // Get the full path of the model.
  ModelName name = ModelName.of(projectId, computeRegion, modelId);

  // Read the image and assign to payload.
  ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath)));
  Image image = Image.newBuilder().setImageBytes(content).build();
  ExamplePayload examplePayload = ExamplePayload.newBuilder().setImage(image).build();

  // Additional parameters that can be provided for prediction e.g. Score Threshold
  Map<String, String> params = new HashMap<>();
  if (scoreThreshold != null) {
    params.put("score_threshold", scoreThreshold);
  }
  // Perform the AutoML Prediction request
  PredictResponse response = predictionClient.predict(name, examplePayload, params);

  System.out.println("Prediction results:");
  for (AnnotationPayload annotationPayload : response.getPayloadList()) {
    System.out.println("Predicted class name :" + annotationPayload.getDisplayName());
    System.out.println(
        "Predicted class score :" + annotationPayload.getClassification().getScore());
  }
}

Node.js

  const automl = require('@google-cloud/automl').v1beta1;
  const fs = require('fs');

  // Create client for prediction service.
  const client = new automl.PredictionServiceClient();

  /**
   * TODO(developer): Uncomment the following lines before running the sample.
   */
  // const projectId = `The GCLOUD_PROJECT string, e.g. "my-gcloud-project"`;
  // const computeRegion = `region-name, e.g. "us-central1"`;
  // const modelId = `id of the model, e.g. "ICN12345"`;
  // const filePath = `local image file path of the content to be classified, e.g. "./resources/test.png"`;
  // const scoreThreshold = `value between 0.0 and 1.0, e.g. "0.5"`;

  // Get the full path of the model.
  const modelFullId = client.modelPath(projectId, computeRegion, modelId);

  // Read the file content for prediction.
  const content = fs.readFileSync(filePath, 'base64');

  const params = {};

  if (scoreThreshold) {
    params.scoreThreshold = scoreThreshold;
  }

  // Set the payload by giving the content and type of the file.
  const payload = {};
  payload.image = {imageBytes: content};

  // Perform the prediction request, passing the score threshold in params.
  client
    .predict({name: modelFullId, payload: payload, params: params})
    .then(responses => {
      console.log(`Prediction results:`);
      responses[0].payload.forEach(result => {
        console.log(`Predicted class name: ${result.displayName}`);
        console.log(`Predicted class score: ${result.classification.score}`);
      });
    })
    .catch(err => {
      console.error(err);
    });

Response

The function returns a classification score indicating how well the image matches each category, listing only the categories whose score meets or exceeds the stated confidence threshold of 0.7.

Prediction results:
Predicted class name: dandelion
Predicted class score: 0.9702693223953247
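If you only need the single best label, pick the highest-scoring entry from the returned payload. A sketch over plain (display_name, score) pairs mirroring the sample output (the second pair is hypothetical, added so there is something to compare against):

```python
# (display_name, score) pairs as returned in the prediction payload;
# the second pair is hypothetical, added for illustration.
predictions = [
    ("dandelion", 0.9702693223953247),
    ("sunflowers", 0.71),
]

best_name, best_score = max(predictions, key=lambda p: p[1])
print(best_name)  # dandelion
```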

Step 6: Delete the model

When you are done using this sample model, you can delete it permanently. You will no longer be able to use the model for prediction.

Request

Make a request with the operation type delete_model to delete a model you created. Pass the Model ID as an argument.

  • python automl_vision_model.py delete_model $MODEL_ID {Python}

  • mvn compile exec:java -Dexec.mainClass="com.google.cloud.vision.samples.automl.ModelApi" -Dexec.args="delete_model $MODEL_ID" {Java}

  • node automlVisionModel.js delete-model $MODEL_ID {Node.js}

Code

Python

# TODO(developer): Uncomment and set the following variables
# project_id = 'PROJECT_ID_HERE'
# compute_region = 'COMPUTE_REGION_HERE'
# model_id = 'MODEL_ID_HERE'

from google.cloud import automl_v1beta1 as automl

client = automl.AutoMlClient()

# Get the full path of the model.
model_full_id = client.model_path(project_id, compute_region, model_id)

# Delete a model.
response = client.delete_model(model_full_id)

# synchronous check of operation status.
print("Model deleted. {}".format(response.result()))

Java

/**
 * Demonstrates using the AutoML client to delete a model.
 *
 * @param projectId the Id of the project.
 * @param computeRegion the Region name.
 * @param modelId the Id of the model.
 * @throws Exception on AutoML Client errors
 */
public static void deleteModel(String projectId, String computeRegion, String modelId)
    throws Exception {
  AutoMlClient client = AutoMlClient.create();

  // Get the full path of the model.
  ModelName modelFullId = ModelName.of(projectId, computeRegion, modelId);

  // Delete a model. deleteModelAsync returns a long-running operation;
  // get() blocks until the deletion completes.
  Empty response = client.deleteModelAsync(modelFullId).get();

  System.out.println("Model deleted.");
}

Node.js

  const automl = require(`@google-cloud/automl`).v1beta1;

  const client = new automl.AutoMlClient();

  /**
   * TODO(developer): Uncomment the following lines before running the sample.
   */
  // const projectId = `The GCLOUD_PROJECT string, e.g. "my-gcloud-project"`;
  // const computeRegion = `region-name, e.g. "us-central1"`;
  // const modelId = `id of the model, e.g. "ICN12345"`;

  // Get the full path of the model.
  const modelFullId = client.modelPath(projectId, computeRegion, modelId);

  // Delete a model.
  client
    .deleteModel({name: modelFullId})
    .then(responses => {
      const operation = responses[0];
      return operation.promise();
    })
    .then(responses => {
      // The final result of the operation.
      if (responses[2].done) {
        console.log(`Model deleted.`);
      }
    })
    .catch(err => {
      console.error(err);
    });

Response

Model deleted.