预测图片分类。
深入探索
如需查看包含此代码示例的详细文档，请参阅以下内容：
代码示例
Go
如需了解详情，请参阅 AutoML Vision Go API 参考文档。
如需向 AutoML Vision 进行身份验证，请设置应用默认凭据。如需了解详情，请参阅“为本地开发环境设置身份验证”。
import (
"context"
"fmt"
"io"
"io/ioutil"
"os"
automl "cloud.google.com/go/automl/apiv1"
"cloud.google.com/go/automl/apiv1/automlpb"
)
// visionClassificationPredict sends one local image to an AutoML Vision
// image-classification model and writes every returned label and its score
// to w.
//
// Example arguments:
//
//	projectID := "my-project-id"
//	location  := "us-central1"
//	modelID   := "ICN123456789..."
//	filePath  := "path/to/image.jpg"
//
// The request sets the "score_threshold" parameter to 0.8 so the service
// filters out lower-confidence results. Each failure is returned wrapped with
// the name of the step that failed (NewPredictionClient, Open, ReadAll,
// Predict).
func visionClassificationPredict(w io.Writer, projectID string, location string, modelID string, filePath string) error {
	ctx := context.Background()

	client, err := automl.NewPredictionClient(ctx)
	if err != nil {
		return fmt.Errorf("NewPredictionClient: %w", err)
	}
	defer client.Close()

	f, err := os.Open(filePath)
	if err != nil {
		return fmt.Errorf("Open: %w", err)
	}
	defer f.Close()

	imageBytes, err := ioutil.ReadAll(f)
	if err != nil {
		return fmt.Errorf("ReadAll: %w", err)
	}

	// Assemble the request piece by piece: model resource name, image payload,
	// and domain-specific params.
	modelName := fmt.Sprintf("projects/%s/locations/%s/models/%s", projectID, location, modelID)
	image := &automlpb.Image{
		Data: &automlpb.Image_ImageBytes{ImageBytes: imageBytes},
	}
	payload := &automlpb.ExamplePayload{
		Payload: &automlpb.ExamplePayload_Image{Image: image},
	}
	// score_threshold is passed as a string because Params is a string map.
	params := map[string]string{"score_threshold": "0.8"}

	resp, err := client.Predict(ctx, &automlpb.PredictRequest{
		Name:    modelName,
		Payload: payload,
		Params:  params,
	})
	if err != nil {
		return fmt.Errorf("Predict: %w", err)
	}

	for _, annotation := range resp.GetPayload() {
		fmt.Fprintf(w, "Predicted class name: %v\n", annotation.GetDisplayName())
		fmt.Fprintf(w, "Predicted class score: %v\n", annotation.GetClassification().GetScore())
	}
	return nil
}
Java
如需了解详情，请参阅 AutoML Vision Java API 参考文档。
如需向 AutoML Vision 进行身份验证，请设置应用默认凭据。如需了解详情，请参阅“为本地开发环境设置身份验证”。
import com.google.cloud.automl.v1.AnnotationPayload;
import com.google.cloud.automl.v1.ExamplePayload;
import com.google.cloud.automl.v1.Image;
import com.google.cloud.automl.v1.ModelName;
import com.google.cloud.automl.v1.PredictRequest;
import com.google.cloud.automl.v1.PredictResponse;
import com.google.cloud.automl.v1.PredictionServiceClient;
import com.google.protobuf.ByteString;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
class VisionClassificationPredict {

  /** Region used when the caller does not specify one. */
  private static final String DEFAULT_LOCATION = "us-central1";

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String projectId = "YOUR_PROJECT_ID";
    String modelId = "YOUR_MODEL_ID";
    String filePath = "path_to_local_file.jpg";
    predict(projectId, modelId, filePath);
  }

  /**
   * Classifies a local image with an AutoML Vision model deployed in {@code us-central1}.
   *
   * <p>Kept for existing callers; delegates to {@link #predict(String, String, String, String)}.
   *
   * @param projectId Google Cloud project that owns the model
   * @param modelId ID of the image-classification model
   * @param filePath path to the local image file to classify
   * @throws IOException if the image file cannot be read or the client fails to initialize
   */
  static void predict(String projectId, String modelId, String filePath) throws IOException {
    predict(projectId, DEFAULT_LOCATION, modelId, filePath);
  }

  /**
   * Classifies a local image with an AutoML Vision model and prints each predicted label and
   * its score.
   *
   * <p>The request sets {@code score_threshold} to {@code 0.8}, so only results scoring above
   * that value are returned.
   *
   * @param projectId Google Cloud project that owns the model
   * @param location region the model is deployed in, e.g. {@code us-central1}
   * @param modelId ID of the image-classification model
   * @param filePath path to the local image file to classify
   * @throws IOException if the image file cannot be read or the client fails to initialize
   */
  static void predict(String projectId, String location, String modelId, String filePath)
      throws IOException {
    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PredictionServiceClient client = PredictionServiceClient.create()) {
      // Get the full path of the model.
      ModelName name = ModelName.of(projectId, location, modelId);
      ByteString content = ByteString.copyFrom(Files.readAllBytes(Paths.get(filePath)));
      Image image = Image.newBuilder().setImageBytes(content).build();
      ExamplePayload payload = ExamplePayload.newBuilder().setImage(image).build();
      PredictRequest predictRequest =
          PredictRequest.newBuilder()
              .setName(name.toString())
              .setPayload(payload)
              // [0.0-1.0] Only produce results higher than this value.
              .putParams("score_threshold", "0.8")
              .build();
      PredictResponse response = client.predict(predictRequest);
      for (AnnotationPayload annotationPayload : response.getPayloadList()) {
        System.out.format("Predicted class name: %s\n", annotationPayload.getDisplayName());
        System.out.format(
            "Predicted class score: %.2f\n", annotationPayload.getClassification().getScore());
      }
    }
  }
}
Node.js
如需了解详情，请参阅 AutoML Vision Node.js API 参考文档。
如需向 AutoML Vision 进行身份验证，请设置应用默认凭据。如需了解详情，请参阅“为本地开发环境设置身份验证”。
/**
 * TODO(developer): Uncomment these variables before running the sample.
 */
// const projectId = 'YOUR_PROJECT_ID';
// const location = 'us-central1';
// const modelId = 'YOUR_MODEL_ID';
// const filePath = 'path_to_local_file.jpg';

// Imports the Google Cloud AutoML library
const {PredictionServiceClient} = require('@google-cloud/automl').v1;
const fs = require('fs');

// Instantiates a client
const client = new PredictionServiceClient();

/**
 * Sends the image at `filePath` to the AutoML Vision image-classification
 * model identified by `projectId` / `location` / `modelId` and logs each
 * predicted label with its score.
 *
 * @returns {Promise<void>} Rejects if the file cannot be read or the
 *   prediction call fails.
 */
async function predict() {
  // Read the image to classify. Doing this inside the async function routes
  // read errors to the same rejection handler as API errors.
  const content = await fs.promises.readFile(filePath);

  // Construct request.
  // params holds additional domain-specific parameters; score_threshold
  // ([0.0-1.0]) filters out lower-confidence results. The value is a string
  // because params is a string-to-string map.
  const request = {
    name: client.modelPath(projectId, location, modelId),
    payload: {
      image: {
        imageBytes: content,
      },
    },
    params: {
      score_threshold: '0.8',
    },
  };

  const [response] = await client.predict(request);
  for (const annotationPayload of response.payload) {
    console.log(`Predicted class name: ${annotationPayload.displayName}`);
    console.log(
      `Predicted class score: ${annotationPayload.classification.score}`
    );
  }
}

// Report failures instead of leaving the promise unhandled.
predict().catch(err => {
  console.error(err);
  process.exitCode = 1;
});
Python
如需了解详情，请参阅 AutoML Vision Python API 参考文档。
如需向 AutoML Vision 进行身份验证，请设置应用默认凭据。如需了解详情，请参阅“为本地开发环境设置身份验证”。
from google.cloud import automl
# TODO(developer): Uncomment and set the following variables
# project_id = "YOUR_PROJECT_ID"
# model_id = "YOUR_MODEL_ID"
# file_path = "path_to_local_file.jpg"

prediction_client = automl.PredictionServiceClient()

# Fully-qualified resource name of the model in the us-central1 region.
model_name = automl.AutoMlClient.model_path(project_id, "us-central1", model_id)

# Load the raw bytes of the local image.
with open(file_path, "rb") as image_file:
    image_bytes = image_file.read()

# Wrap the bytes as the example payload to classify.
example = automl.ExamplePayload(image=automl.Image(image_bytes=image_bytes))

# params carries additional domain-specific parameters; score_threshold is
# used to filter the results.
# https://cloud.google.com/automl/docs/reference/rpc/google.cloud.automl.v1#predictrequest
predict_request = automl.PredictRequest(
    name=model_name,
    payload=example,
    params={"score_threshold": "0.8"},
)
response = prediction_client.predict(request=predict_request)

print("Prediction results:")
for annotation in response.payload:
    print(f"Predicted class name: {annotation.display_name}")
    print(f"Predicted class score: {annotation.classification.score}")
后续步骤
如需搜索和过滤其他 Google Cloud 产品的代码示例，请参阅 Google Cloud 示例浏览器。