获取在线预测结果

AI Platform Prediction 在线预测是一项经过优化的服务,可通过托管模型运行您的数据,并且将延迟控制在最低限度。您只需将小批量数据发送到服务,之后服务将做出响应并返回预测。

了解在线预测与批量预测对比或者参阅预测概念的概览

准备工作

要请求预测,您首先必须完成以下操作:

地区

部分区域提供在线预测服务。此外,每个区域还提供不同的机器类型。如需了解在线预测在每个区域的可用性,请参阅区域指南

创建模型和版本

在创建模型和版本资源时,您要针对如何运行在线预测做出以下重要决策:

已创建的资源 在创建资源时做出的决策
模型 运行预测的区域
模型 启用在线预测日志记录
版本 要使用的运行时版本
版本 要使用的 Python 版本
版本 用于在线预测的机器类型

在首次创建模型或版本后,您无法更新上面列出的设置。如果需要更改这些设置,请在创建新模型或版本资源时使用新的设置,然后重新部署模型。

可用于在线预测的机器类型

在创建版本时,您可以选择 AI Platform Prediction 用于在线预测节点的虚拟机类型。详细了解机器类型

获取在线预测请求的日志

默认情况下,AI Platform Prediction 预测服务不提供有关请求的日志信息,以免产生费用。这是因为,以较高的每秒查询次数 (QPS) 进行的在线预测可能会生成大量日志,而这些日志要按照 Cloud Logging 价格BigQuery 价格计费。

如果要启用在线预测日志记录,则必须在创建模型资源创建模型版本资源时对其进行配置,具体取决于您想要启用的日志记录类型。日志记录有三种类型,您可以单独启用这三类日志记录:

  • 访问日志记录,用于将每个请求的时间戳和延迟时间等信息记录到 Cloud Logging 中。

    您可以在创建模型资源时启用访问日志记录。

  • 控制台日志记录,用于将来自预测节点的 stderrstdout 信息流记录到 Cloud Logging 中,这对于调试非常有用。此类型的日志记录对于 Compute Engine (N1) 机器类型为预览版,对于旧版 (MLS1) 机器类型则为正式版。

    您可以在创建模型资源时启用控制台日志记录。

  • 请求-响应日志记录,用于将在线预测请求和响应的样本记录到 BigQuery 表中。此类日志记录处于测试阶段。

    若要启用请求-响应日志记录,您可以创建一个模型版本资源,然后更新该版本

gcloud

要启用访问日志记录,请在使用 gcloud ai-platform models create 命令创建模型时添加 --enable-logging 标志。例如:

gcloud ai-platform models create MODEL_NAME \
  --region=us-central1 \
  --enable-logging

要启用控制台日志记录(预览),请使用 gcloud beta 组件并添加 --enable-console-logging 标志。例如:

gcloud components install beta

gcloud beta ai-platform models create MODEL_NAME \
  --region=us-central1 \
  --enable-console-logging

目前,您无法使用 gcloud 工具启用请求-响应日志记录(测试版)。您只能在向 REST API 发送 projects.models.versions.patch 请求时启用此类日志记录。

REST API

要启用访问日志记录,请在使用 projects.models.create 创建模型时,在“模型”资源中将 onlinePredictionLogging 设置为 True

若要启用控制台日志记录(Beta 版),请在“模型”资源中将 onlinePredictionConsoleLogging 字段设置为 True

请求-响应日志记录

与其他类型的日志记录不同,您无法在创建模型时启用请求-响应日志记录。要启用这种日志记录,您可以对现有的模型版本使用 projects.models.versions.patch 方法。(您必须先使用 Google Cloud Console、gcloud 工具或 REST API 创建模型版本。)

要启用请求-响应日志记录,请使用以下条目填写版本资源的 requestLoggingConfig 字段

  • samplingPercentage:此标志表示介于 0 和 1 之间的一个数字,用于定义要记录的请求的比例。例如,将此值设置为 1 以便记录所有请求,或设置为 0.1 以便记录 10% 的请求。
  • bigqueryTableName:您要在其中记录请求和响应的 BigQuery 表的完全限定名称 (PROJECT_ID.DATASET_NAME.TABLE_NAME)。该表必须已存在且具有以下架构

    字段名称类型模式
    model字符串必填
    model_version字符串必填
    time时间戳必填
    raw_data字符串必填
    raw_prediction字符串可以为 Null
    groundtruth字符串可以为 Null

    了解如何创建 BigQuery 表

使用 What-If 工具检查模型

您可以通过交互式信息中心,在笔记本环境中使用 What-If 工具 (WIT) 来检查 AI Platform Prediction 模型。What-If 工具与 TensorBoard、Jupyter 笔记本、Colab 笔记本和 JupyterHub 集成。它还预装在 AI Platform Notebooks TensorFlow 实例上。

了解如何将 What-If 工具用于 AI Platform

格式化输入以进行在线预测

将实例格式化为 JSON 字符串

在线预测的基本格式是数据实例列表。这些列表可以是普通的值列表,也可以是 JSON 对象的成员,具体取决于您在训练应用中配置输入的方式。TensorFlow 模型和自定义预测例程可接受更复杂的输入,而大多数 scikit-learn 和 XGBoost 模型希望使用数字列表作为输入。

以下示例显示了 TensorFlow 模型的输入张量和实例键:

{"values": [1, 2, 3, 4], "key": 1}

只要遵循以下规则,JSON 字符串的构成可能会很复杂:

  • 顶级实例数据必须是 JSON 对象,即键值对的字典。

  • 实例对象中的各个值可以是字符串、数字或列表。 您无法嵌入 JSON 对象。

  • 列表必须仅包含相同类型的内容(包括其他列表)。您不能混合使用字符串和数值。

您将在线预测的输入实例作为 projects.predict 调用的消息正文进行传递。详细了解请求正文的格式要求

gcloud

您可以采用两种不同的方式来设置输入的格式,具体取决于您计划如何发送预测请求。我们建议您使用 gcloud ai-platform predict 命令的 --json-request 标志。另外,您可以将 --json-instances 标志与以换行符分隔的 JSON 数据结合使用。

对于 --json-request

使每个实例成为 JSON 数组中的一项,并将该数组作为 JSON 文件的 instances 字段提供。例如:

instances.json

{"instances": [
  {"values": [1, 2, 3, 4], "key": 1},
  {"values": [5, 6, 7, 8], "key": 2}
]}

对于 --json-instances

确保输入文件是以换行符分隔的 JSON 文件,其中每个实例都是 JSON 对象,并且每行一个实例。例如:

instances.jsonl

{"values": [1, 2, 3, 4], "key": 1}
{"values": [5, 6, 7, 8], "key": 2}

REST API

使每个实例成为 JSON 数组中的一项,并将该数组作为 JSON 对象的 instances 字段提供。例如:

{"instances": [
  {"values": [1, 2, 3, 4], "key": 1},
  {"values": [5, 6, 7, 8], "key": 2}
]}

预测输入中的二进制数据

二进制数据不能格式化为 JSON 支持的 UTF-8 编码字符串。如果输入中包含二进制数据,则必须使用 base64 编码来表示它。此时需要用到以下特殊格式:

  • 编码的字符串必须设置为 JSON 对象格式,并包含名为 b64 的单个键。以下 Python 2.7 示例使用 base64 库对原始 JPEG 数据的缓冲区进行编码以生成实例:

    {"image_bytes": {"b64": base64.b64encode(jpeg_data)}}
    

    在 Python 3 中,base64 编码会输出一个字节序列。您必须将此字节序列转换为字符串,使其能够进行 JSON 序列化:

    {'image_bytes': {'b64': base64.b64encode(jpeg_data).decode()}}
    
  • 在 TensorFlow 模型代码中,您必须为二进制输入和输出张量提供别名,以便它们以“_bytes”结尾。

请求预测

可通过在预测请求中将输入数据实例作为 JSON 字符串发送来请求在线预测。如需了解请求和响应正文的格式,请参阅预测请求的详细信息

如果未指定模型版本,则预测请求将使用模型的默认版本

gcloud

  1. 创建环境变量来保存参数,如果您决定指定特定的模型版本,请添加版本值:

    MODEL_NAME="[YOUR-MODEL-NAME]"
    INPUT_DATA_FILE="instances.json"
    VERSION_NAME="[YOUR-VERSION-NAME]"
    
  2. 使用 gcloud ai-platform predict 将实例发送到已部署的模型。请注意,--version 是可选的。

    gcloud ai-platform predict \
      --model=$MODEL_NAME \
      --version=$VERSION_NAME \
      --json-request=$INPUT_DATA_FILE \
      --region=REGION
    

    REGION 替换为您在其上创建了模型地区端点的区域。如果您在全球端点上创建模型,请省略 --region 标志。

  3. gcloud 工具会解析响应并以简明易懂的格式将预测结果输出到终端。您可以通过在 predict 命令中使用 --format 标志来指定不同的输出格式,例如 JSON 或 CSV。查看可用的输出格式

Python

您可以使用 Python 版 Google API 客户端库调用 AI Platform Training and Prediction API,而无需手动构建 HTTP 请求。在运行以下代码示例之前,必须先设置身份验证。

# Create the AI Platform service object.
# To authenticate set the environment variable
# GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
service = googleapiclient.discovery.build('ml', 'v1')

def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the AI Platform Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']

Java

您可以使用 Java 版 Google API 客户端库调用 AI Platform Training and Prediction API,而无需手动构建 HTTP 请求。在运行以下代码示例之前,必须先设置身份验证。

/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.FileContent;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.UriTemplate;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.discovery.Discovery;
import com.google.api.services.discovery.model.JsonSchema;
import com.google.api.services.discovery.model.RestDescription;
import com.google.api.services.discovery.model.RestMethod;
import com.google.auth.http.HttpCredentialsAdapter;
import com.google.auth.oauth2.GoogleCredentials;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

/*
 * Sample code for sending an online prediction request to Cloud Machine Learning Engine.
 */

public class OnlinePredictionSample {
  public static void main(String[] args) throws Exception {
    HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
    JsonFactory jsonFactory = JacksonFactory.getDefaultInstance();
    Discovery discovery = new Discovery.Builder(httpTransport, jsonFactory, null).build();

    RestDescription api = discovery.apis().getRest("ml", "v1").execute();
    RestMethod method = api.getResources().get("projects").getMethods().get("predict");

    JsonSchema param = new JsonSchema();
    String projectId = "YOUR_PROJECT_ID";
    // You should have already deployed a model and a version.
    // For reference, see https://cloud.google.com/ml-engine/docs/deploying-models.
    String modelId = "YOUR_MODEL_ID";
    String versionId = "YOUR_VERSION_ID";
    param.set(
        "name", String.format("projects/%s/models/%s/versions/%s", projectId, modelId, versionId));

    GenericUrl url =
        new GenericUrl(UriTemplate.expand(api.getBaseUrl() + method.getPath(), param, true));
    System.out.println(url);

    String contentType = "application/json";
    File requestBodyFile = new File("input.txt");
    HttpContent content = new FileContent(contentType, requestBodyFile);
    System.out.println(content.getLength());

    List<String> scopes = new ArrayList<>();
    scopes.add("https://www.googleapis.com/auth/cloud-platform");

    GoogleCredentials credential = GoogleCredentials.getApplicationDefault().createScoped(scopes);
    HttpRequestFactory requestFactory =
        httpTransport.createRequestFactory(new HttpCredentialsAdapter(credential));
    HttpRequest request = requestFactory.buildRequest(method.getHttpMethod(), url, content);

    String response = request.execute().parseAsString();
    System.out.println(response);
  }
}

对在线预测进行问题排查

在线预测中的常见错误包括:

  • 内存不足错误
  • 输入数据的格式不正确
  • 一项在线预测请求所包含的数据量不能超过 1.5 MB。 如果使用 gcloud 工具创建请求,则每个文件最多只能处理 100 个实例。若要同时为更多实例获取预测结果,请使用批量预测。

在将模型部署到 AI Platform Prediction 进行预测之前,请尝试缩减模型大小

如需了解详情,请参阅排查在线预测问题

后续步骤