Predict

Predict

Explore further

For detailed documentation that includes this code sample, see the following:

Code sample

Go

To learn how to install and use the client library for AutoML Vision Object Detection, see AutoML Vision Object Detection client libraries. For more information, see the AutoML Vision Object Detection Go API reference documentation.

To authenticate to AutoML Vision Object Detection, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import (
	"context"
	"fmt"
	"io"
	"io/ioutil"
	"os"

	automl "cloud.google.com/go/automl/apiv1"
	"cloud.google.com/go/automl/apiv1/automlpb"
)

// visionObjectDetectionPredict does a prediction for image classification.
func visionObjectDetectionPredict(w io.Writer, projectID string, location string, modelID string, filePath string) error {
	// projectID := "my-project-id"
	// location := "us-central1"
	// modelID := "IOD123456789..."
	// filePath := "path/to/image.jpg"

	ctx := context.Background()
	client, err := automl.NewPredictionClient(ctx)
	if err != nil {
		return fmt.Errorf("NewPredictionClient: %w", err)
	}
	defer client.Close()

	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("Open: %w", err)
	}
	defer file.Close()
	bytes, err := ioutil.ReadAll(file)
	if err != nil {
		return fmt.Errorf("ReadAll: %w", err)
	}

	req := &automlpb.PredictRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/models/%s", projectID, location, modelID),
		Payload: &automlpb.ExamplePayload{
			Payload: &automlpb.ExamplePayload_Image{
				Image: &automlpb.Image{
					Data: &automlpb.Image_ImageBytes{
						ImageBytes: bytes,
					},
				},
			},
		},
		// Params is additional domain-specific parameters.
		Params: map[string]string{
			// score_threshold is used to filter the result.
			"score_threshold": "0.8",
		},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return fmt.Errorf("Predict: %w", err)
	}

	for _, payload := range resp.GetPayload() {
		fmt.Fprintf(w, "Predicted class name: %v\n", payload.GetDisplayName())
		fmt.Fprintf(w, "Predicted class score: %v\n", payload.GetImageObjectDetection().GetScore())
		boundingBox := payload.GetImageObjectDetection().GetBoundingBox()
		fmt.Fprintf(w, "Normalized vertices:\n")
		for _, vertex := range boundingBox.GetNormalizedVertices() {
			fmt.Fprintf(w, "\tX: %v, Y: %v\n", vertex.GetX(), vertex.GetY())
		}
	}

	return nil
}

Java

To learn how to install and use the client library for AutoML Vision Object Detection, see AutoML Vision Object Detection client libraries. For more information, see the AutoML Vision Object Detection Java API reference documentation.

To authenticate to AutoML Vision Object Detection, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

import com.google.cloud.automl.v1.AnnotationPayload;
import com.google.cloud.automl.v1.BoundingPoly;
import com.google.cloud.automl.v1.ExamplePayload;
import com.google.cloud.automl.v1.Image;
import com.google.cloud.automl.v1.ModelName;
import com.google.cloud.automl.v1.NormalizedVertex;
import com.google.cloud.automl.v1.PredictRequest;
import com.google.cloud.automl.v1.PredictResponse;
import com.google.cloud.automl.v1.PredictionServiceClient;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

class VisionObjectDetectionPredict {

  static void predict() throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    String filePath = "path_to_local_file.jpg";
    predict(projectId, modelId, filePath);
  }

  static void predict(String projectId, String modelId, String filePath) throws IOException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient client = PredictionServiceClient.create()) {
      // Get the full path of the model.
      ModelName name = ModelName.of(projectId, "us-central1", modelId);
      ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath)));
      Image image = Image.newBuilder().setImageBytes(content).build();
      ExamplePayload payload = ExamplePayload.newBuilder().setImage(image).build();
      PredictRequest predictRequest =
          PredictRequest.newBuilder()
              .setName(name.toString())
              .setPayload(payload)
              .putParams(
                  "score_threshold", "0.5") // [0.0-1.0] Only produce results higher than this value
              .build();

      PredictResponse response = client.predict(predictRequest);
      for (AnnotationPayload annotationPayload : response.getPayloadList()) {
        System.out.format("Predicted class name: %s\n", annotationPayload.getDisplayName());
        System.out.format(
            "Predicted class score: %.2f\n",
            annotationPayload.getImageObjectDetection().getScore());
        BoundingPoly boundingPoly = annotationPayload.getImageObjectDetection().getBoundingBox();
        System.out.println("Normalized Vertices:");
        for (NormalizedVertex vertex : boundingPoly.getNormalizedVerticesList()) {
          System.out.format("\tX: %.2f, Y: %.2f\n", vertex.getX(), vertex.getY());
        }
      }
    }
  }
}

Node.js

To learn how to install and use the client library for AutoML Vision Object Detection, see AutoML Vision Object Detection client libraries. For more information, see the AutoML Vision Object Detection Node.js API reference documentation.

To authenticate to AutoML Vision Object Detection, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const modelId = 'YOUR_MODEL_ID';
// const filePath = 'path_to_local_file.jpg';

// Imports the Google Cloud AutoML library
const {PredictionServiceClient} = require('@google-cloud/automl').v1;
const fs = require('fs');

// Instantiates a client
const client = new PredictionServiceClient();

// Read the file content for translation.
const content = fs.readFileSync(filePath);

async function predict() {
  // Construct request
  // params is additional domain-specific parameters.
  // score_threshold is used to filter the result
  const request = {
    name: client.modelPath(projectId, location, modelId),
    payload: {
      image: {
        imageBytes: content,
      },
    },
    params: {
      score_threshold: '0.8',
    },
  };

  const [response] = await client.predict(request);

  for (const annotationPayload of response.payload) {
    console.log(`Predicted class name: ${annotationPayload.displayName}`);
    console.log(
      `Predicted class score: ${annotationPayload.imageObjectDetection.score}`
    );
    console.log('Normalized vertices:');
    for (const vertex of annotationPayload.imageObjectDetection.boundingBox
      .normalizedVertices) {
      console.log(`\tX: ${vertex.x}, Y: ${vertex.y}`);
    }
  }
}

predict();

Python

To learn how to install and use the client library for AutoML Vision Object Detection, see AutoML Vision Object Detection client libraries. For more information, see the AutoML Vision Object Detection Python API reference documentation.

To authenticate to AutoML Vision Object Detection, set up Application Default Credentials. For more information, see Set up authentication for a local development environment.

from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"
# file_path = "path_to_local_file.jpg"

prediction_client = automl.PredictionServiceClient()

# Get the full path of the model.
model_full_id = automl.AutoMlClient.model_path(project_id, "us-central1", model_id)

# Read the file.
with open(file_path, "rb") as content_file:
    content = content_file.read()

image = automl.Image(image_bytes=content)
payload = automl.ExamplePayload(image=image)

# params is additional domain-specific parameters.
# score_threshold is used to filter the result
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#predictrequest
params = {"score_threshold": "0.8"}

request = automl.PredictRequest(name=model_full_id, payload=payload, params=params)

response = prediction_client.predict(request=request)
print("Prediction results:")
for result in response.payload:
    print(f"Predicted class name: {result.display_name}")
    print(f"Predicted class score: {result.image_object_detection.score}")
    bounding_box = result.image_object_detection.bounding_box
    print("Normalized Vertices:")
    for vertex in bounding_box.normalized_vertices:
        print(f"\tX: {vertex.x}, Y: {vertex.y}")

What's next

To search and filter code samples for other Google Cloud products, see the Google Cloud sample browser.