예측
더 살펴보기
이 코드 샘플이 포함된 자세한 문서는 다음을 참조하세요.
코드 샘플
Go
import (
"context"
"fmt"
"io"
"io/ioutil"
"os"
automl "cloud.google.com/go/automl/apiv1"
"cloud.google.com/go/automl/apiv1/automlpb"
)
// visionObjectDetectionPredict runs an AutoML Vision object detection
// prediction on a local image file and writes, for every returned annotation,
// its display name, score, and normalized bounding-box vertices to w.
//
// projectID, location, and modelID identify the model as
// projects/{projectID}/locations/{location}/models/{modelID}; filePath is the
// local image whose bytes are sent in the request. Errors are wrapped with %w
// so callers can inspect the underlying cause with errors.Is / errors.As.
func visionObjectDetectionPredict(w io.Writer, projectID string, location string, modelID string, filePath string) error {
	// projectID := "my-project-id"
	// location := "us-central1"
	// modelID := "IOD123456789..."
	// filePath := "path/to/image.jpg"

	// Read the image first so a bad path fails before any client is created.
	file, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("Open: %w", err)
	}
	defer file.Close()
	content, err := ioutil.ReadAll(file)
	if err != nil {
		return fmt.Errorf("ReadAll: %w", err)
	}

	ctx := context.Background()
	client, err := automl.NewPredictionClient(ctx)
	if err != nil {
		return fmt.Errorf("NewPredictionClient: %w", err)
	}
	defer client.Close()

	req := &automlpb.PredictRequest{
		Name: fmt.Sprintf("projects/%s/locations/%s/models/%s", projectID, location, modelID),
		Payload: &automlpb.ExamplePayload{
			Payload: &automlpb.ExamplePayload_Image{
				Image: &automlpb.Image{
					Data: &automlpb.Image_ImageBytes{
						ImageBytes: content,
					},
				},
			},
		},
		// Params is additional domain-specific parameters.
		Params: map[string]string{
			// score_threshold is used to filter the result.
			"score_threshold": "0.8",
		},
	}

	resp, err := client.Predict(ctx, req)
	if err != nil {
		return fmt.Errorf("Predict: %w", err)
	}

	for _, payload := range resp.GetPayload() {
		detection := payload.GetImageObjectDetection()
		fmt.Fprintf(w, "Predicted class name: %v\n", payload.GetDisplayName())
		fmt.Fprintf(w, "Predicted class score: %v\n", detection.GetScore())
		fmt.Fprintf(w, "Normalized vertices:\n")
		for _, vertex := range detection.GetBoundingBox().GetNormalizedVertices() {
			fmt.Fprintf(w, "\tX: %v, Y: %v\n", vertex.GetX(), vertex.GetY())
		}
	}
	return nil
}
Java
import com.google.cloud.automl.v1.AnnotationPayload;
import com.google.cloud.automl.v1.BoundingPoly;
import com.google.cloud.automl.v1.ExamplePayload;
import com.google.cloud.automl.v1.Image;
import com.google.cloud.automl.v1.ModelName;
import com.google.cloud.automl.v1.NormalizedVertex;
import com.google.cloud.automl.v1.PredictRequest;
import com.google.cloud.automl.v1.PredictResponse;
import com.google.cloud.automl.v1.PredictionServiceClient;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
class VisionObjectDetectionPredict {
  /** Region used by the overloads that do not take an explicit location. */
  private static final String DEFAULT_LOCATION = "us-central1";

  /** Runs the sample with placeholder values that must be replaced before use. */
  static void predict() throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    String filePath = "path_to_local_file.jpg";
    predict(projectId, modelId, filePath);
  }

  /**
   * Runs an object detection prediction against a model located in {@code us-central1}.
   *
   * @param projectId Google Cloud project that owns the model
   * @param modelId AutoML model identifier
   * @param filePath local image file whose bytes are sent for prediction
   * @throws IOException if the image cannot be read or the client cannot be created
   */
  static void predict(String projectId, String modelId, String filePath) throws IOException {
    predict(projectId, DEFAULT_LOCATION, modelId, filePath);
  }

  /**
   * Runs an object detection prediction and prints each detected object's display name, score,
   * and normalized bounding-box vertices to standard output.
   *
   * @param projectId Google Cloud project that owns the model
   * @param location region of the model, e.g. {@code us-central1}
   * @param modelId AutoML model identifier
   * @param filePath local image file whose bytes are sent for prediction
   * @throws IOException if the image cannot be read or the client cannot be created
   */
  static void predict(String projectId, String location, String modelId, String filePath)
      throws IOException {
    // Read the image before opening a client so a bad path fails fast.
    ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath)));

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient client = PredictionServiceClient.create()) {
      // Get the full path of the model.
      ModelName name = ModelName.of(projectId, location, modelId);
      Image image = Image.newBuilder().setImageBytes(content).build();
      ExamplePayload payload = ExamplePayload.newBuilder().setImage(image).build();
      PredictRequest predictRequest =
          PredictRequest.newBuilder()
              .setName(name.toString())
              .setPayload(payload)
              .putParams(
                  "score_threshold", "0.5") // [0.0-1.0] Only produce results higher than this value
              .build();

      PredictResponse response = client.predict(predictRequest);
      for (AnnotationPayload annotationPayload : response.getPayloadList()) {
        System.out.format("Predicted class name: %s\n", annotationPayload.getDisplayName());
        System.out.format(
            "Predicted class score: %.2f\n",
            annotationPayload.getImageObjectDetection().getScore());
        BoundingPoly boundingPoly = annotationPayload.getImageObjectDetection().getBoundingBox();
        System.out.println("Normalized Vertices:");
        for (NormalizedVertex vertex : boundingPoly.getNormalizedVerticesList()) {
          System.out.format("\tX: %.2f, Y: %.2f\n", vertex.getX(), vertex.getY());
        }
      }
    }
  }
}
Node.js
/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const modelId = 'YOUR_MODEL_ID';
// const filePath = 'path_to_local_file.jpg';

// Imports the Google Cloud AutoML library
const {PredictionServiceClient} = require('@google-cloud/automl').v1;
const fs = require('fs');

// Instantiates a client
const client = new PredictionServiceClient();

// Read the image file whose bytes are sent for object detection prediction.
const content = fs.readFileSync(filePath);
/**
 * Sends the image bytes held in `content` to the model and logs every
 * detected object's display name, score, and normalized bounding-box vertices.
 */
async function predict() {
  // `params` carries additional domain-specific parameters;
  // score_threshold filters out results below the given score.
  const modelName = client.modelPath(projectId, location, modelId);
  const request = {
    name: modelName,
    payload: {image: {imageBytes: content}},
    params: {score_threshold: '0.8'},
  };

  const [response] = await client.predict(request);
  for (const {displayName, imageObjectDetection} of response.payload) {
    console.log(`Predicted class name: ${displayName}`);
    console.log(`Predicted class score: ${imageObjectDetection.score}`);
    console.log('Normalized vertices:');
    const {normalizedVertices} = imageObjectDetection.boundingBox;
    for (const {x, y} of normalizedVertices) {
      console.log(`\tX: ${x}, Y: ${y}`);
    }
  }
}
// Run the sample. Log any failure and set a non-zero exit code instead of
// leaving the returned promise's rejection unhandled.
predict().catch(err => {
  console.error(err);
  process.exitCode = 1;
});
Python
from google.cloud import automl

# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"
# file_path = "path_to_local_file.jpg"

prediction_client = automl.PredictionServiceClient()

# Build the fully qualified resource name of the model.
model_full_id = automl.AutoMlClient.model_path(project_id, "us-central1", model_id)

# Load the raw image bytes from disk.
with open(file_path, "rb") as image_file:
    image_bytes = image_file.read()

# params carries additional domain-specific parameters; score_threshold
# filters out results below the given score.
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#predictrequest
request = automl.PredictRequest(
    name=model_full_id,
    payload=automl.ExamplePayload(image=automl.Image(image_bytes=image_bytes)),
    params={"score_threshold": "0.8"},
)
response = prediction_client.predict(request=request)

print("Prediction results:")
for result in response.payload:
    detection = result.image_object_detection
    print(f"Predicted class name: {result.display_name}")
    print(f"Predicted class score: {detection.score}")
    print("Normalized Vertices:")
    for vertex in detection.bounding_box.normalized_vertices:
        print(f"\tX: {vertex.x}, Y: {vertex.y}")
다음 단계
다른 Google Cloud 제품의 코드 샘플을 검색하고 필터링하려면 Google Cloud 샘플 브라우저를 참조하세요.