取得線上預測

AI Platform 線上預測這項服務經過最佳化處理,會儘量以最短的延遲時間透過託管模型執行您的資料。當您將小批次的資料傳送給這項服務後,服務會在回應中傳回您的預測結果。

瞭解線上與批次預測或參閱預測概念總覽

事前準備

為了要求預測,您必須先:

  • 透過最佳化要進行預測的圖形,確認 SavedModel 的檔案大小不超過 AI Platform 的預設上限 250 MB

  • 驗證輸入資料是否採用線上預測的正確格式

地區

目前 AI Platform 線上預測服務可在下列地區使用:

  • us-central1
  • us-east1
  • us-east4
  • asia-northeast1
  • europe-west1

如要完全瞭解 AI Platform 訓練和預測服務的可用地區,請參閱地區指南

建立模型和版本

建立模型和版本資源時,請針對如何執行線上預測功能做出以下重大決策:

已建立的資源 在建立資源時指定的決策
模型 執行預測的地區
模型 啟用線上預測記錄
版本 要使用的執行階段版本
版本 要使用的 Python 版本
版本 用於線上預測的機器類型

初始建立模型或版本後,您就無法更新上方所列的設定。如果您需要變更這些設定,請使用新設定來建立新的模型或版本資源,並重新部署您的模型。

適用於線上預測的機器類型

線上預測目前支援單核和四核 CPU。如果您有興趣參加其他硬體的 Alpha 版計畫,請聯絡 AI Platform 意見回饋

名稱 核心數 RAM (GB)
mls1-c1-m2 (預設) 1 2
mls1-c4-m2 (Beta 版) 4 2

如要在線上預測中使用四核 CPU,請在建立 AI Platform 模型版本時將機器類型指定為 mls1-c4-m2

gcloud

如要指定四核 CPU,請使用 gcloud beta 指令,並於指令中加入下列選用標記:--machine-type "mls1-c4-m2"。例如:

  gcloud components install beta

  gcloud beta ai-platform versions create "[YOUR_VERSION_NAME]" \
      --model "[YOUR_MODEL_NAME]" \
      --origin "[YOUR_GCS_PATH_TO_MODEL_DIRECTORY]" \
      --runtime-version "1.13" \
      --python-version "3.5" \
      --machine-type "mls1-c4-m2"

Python

本範例使用 Google API Python 專用用戶端程式庫。請參閱如何建立模型版本的完整操作說明

如要指定四核 CPU,請將下列選用性項目加入您傳送給版本建立要求的 requestDict'machineType': 'mls1-c4-m2'。例如:

   requestDict = {'name': '[YOUR_VERSION_NAME]',
      'description': '[YOUR_VERSION_DESCRIPTION]',
      'deploymentUri': '[YOUR_GCS_PATH_TO_MODEL_DIRECTORY]',
      'runtimeVersion': '1.13',
      'pythonVersion': '3.5',
      'machineType': 'mls1-c4-m2'}

查看這些機器類型的計費資訊。

線上預測要求的要求記錄

根據預設,AI Platform 預測服務不會提供有關要求的記錄資訊,因為記錄會產生費用。以較高的每秒查詢次數 (QPS) 執行的線上預測服務會產生大量記錄,這些記錄受到 Stackdriver 定價政策的規範。

如要選擇啟用線上預測記錄,您可以在建立模型資源時,將模型設定為產生記錄。您可以單獨啟用兩種類型的記錄:

  • 存取記錄,包含每個要求時間戳記與延遲時間等資訊。

  • 串流記錄,包含來自預測節點的 stderrstdout 串流,且可用於偵錯。這類記錄仍在 Beta 版測試階段。

gcloud

如要啟用存取記錄,請在以 gcloud ai-platform models create 指令建立模型時包含 --enable-logging 標記。例如:

gcloud ai-platform models create model_name \
  --regions us-central1 \
  --enable-logging

如要啟用串流記錄 (Beta 版),請使用 gcloud beta 元件並包含 --enable-console-logging 標記。例如:

gcloud components install beta

gcloud beta ai-platform models create model_name \
  --regions us-central1 \
  --enable-console-logging

REST API

如要啟用存取記錄,請在使用 projects.models.create 建立模型時,在 Model 資源中將 onlinePredictionLogging 設定為 True

如要啟用串流記錄 (Beta 版),請在 Model 資源中將 onlinePredictionConsoleLogging 欄位設定為 True

將用於線上預測的輸入內容格式化

將樣本格式化為 JSON 字串

線上預測的基本格式是一份樣本資料張量清單。張量清單可以是單純的值清單或 JSON 物件成員清單,視您在訓練應用程式中設定輸入的方式而定。

以下範例顯示輸入張量和樣本索引鍵:

{"values": [1, 2, 3, 4], "key": 1}

如果 JSON 字串遵循下列規則,其組成內容可能會非常複雜:

  • 樣本資料的頂層必須是 JSON 物件,亦即鍵/值組合的字典。

  • 樣本物件的個別值可為字串、數字或清單。您無法嵌入 JSON 物件。

  • 清單僅能包含相同類型的項目 (包括其他清單)。您不能混合使用字串和數值。

將用於線上預測的輸入樣本做為 projects.predict 呼叫的訊息內文傳送。進一步瞭解要求主體的格式需求

gcloud

  1. 確定您的輸入檔是文字檔,其中每個樣本是一個 JSON 物件,每行一個樣本。

    {"values": [1, 2, 3, 4], "key": 1}
    {"values": [5, 6, 7, 8], "key": 2}
    

REST API

  1. 將每個樣本設為清單中的一個項目,並將清單成員命名為 instances

    {"instances": [
      {"values": [1, 2, 3, 4], "key": 1},
      {"values": [5, 6, 7, 8], "key": 2}
    ]}
    

預測輸入的二進位資料

二進位資料無法格式化為 JSON 支援的 UTF-8 編碼字串。如果輸入內容中有二進位資料,您必須使用 base64 編碼來表示。下列是必要的特殊格式設定:

  • 您的編碼字串必須格式化為具有 b64 單一索引鍵的 JSON 物件。下列 Python 2.7 範例使用 base64 程式庫對原始 JPEG 資料的緩衝區進行編碼,以建立樣本:

    {"image_bytes": {"b64": base64.b64encode(jpeg_data)}}
    

    在 Python 3.5 中,base64 編碼會輸出一個位元組序列。您必須將這個序列轉換成字串,讓它可透過 JSON 序列化:

    {'image_bytes': {'b64': base64.b64encode(jpeg_data).decode()}}
    
  • 在 TensorFlow 模型程式碼中,您必須為二進位輸入和輸出張量提供結尾為「_bytes」的別名。

要求預測

predict 要求中,將輸入資料樣本以 JSON 字串傳送,以要求線上預測。如要格式化要求及回應主體,請參閱預測要求的詳細資料一文。

如果您沒有指定模型版本,則預測要求會使用模型的預設版本

gcloud

  1. 建立環境變數以保留參數,包括版本值 (如果您決定指定特定模型版本):

    MODEL_NAME="[YOUR-MODEL-NAME]"
    INPUT_DATA_FILE="instances.json"
    VERSION_NAME="[YOUR-VERSION-NAME]"
    
  2. 使用 gcloud ai-platform predict 將樣本傳送至已部署的模型。請注意,--version 為選用。

    gcloud ai-platform predict --model $MODEL_NAME  \
                       --version $VERSION_NAME \
                       --json-instances $INPUT_DATA_FILE
    
  3. gcloud 工具會剖析回應,並以使用者可理解的格式,將預測結果輸出到您的終端機。您可以在 predict 指令中使用 --format 標記來指定其他輸出格式,例如 JSON 或 CSV。查看支援的輸出格式

Python

您可以使用 Python 適用的 Google API 用戶端程式庫呼叫 AI Platform Training 和 Prediction API,不需要手動建構 HTTP 要求。您必須先設定驗證,才能執行下列程式碼範例。

def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the AI Platform Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    # Create the AI Platform service object.
    # To authenticate set the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    service = googleapiclient.discovery.build('ml', 'v1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']

Java

您可以使用 Java 適用的 Google API 用戶端程式庫呼叫 AI Platform Training 和 Prediction API,不需要手動建構 HTTP 要求。您必須先設定驗證,才能執行下列程式碼範例。

/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.FileContent;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.UriTemplate;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.discovery.Discovery;
import com.google.api.services.discovery.model.JsonSchema;
import com.google.api.services.discovery.model.RestDescription;
import com.google.api.services.discovery.model.RestMethod;
import java.io.File;

/*
 * Sample code for doing Cloud Machine Learning Engine online prediction in Java.
 */

public class OnlinePredictionSample {
  public static void main(String[] args) throws Exception {
    HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
    JsonFactory jsonFactory = JacksonFactory.getDefaultInstance();
    Discovery discovery = new Discovery.Builder(httpTransport, jsonFactory, null).build();

    RestDescription api = discovery.apis().getRest("ml", "v1").execute();
    RestMethod method = api.getResources().get("projects").getMethods().get("predict");

    JsonSchema param = new JsonSchema();
    String projectId = "YOUR_PROJECT_ID";
    // You should have already deployed a model and a version.
    // For reference, see https://cloud.google.com/ml-engine/docs/deploying-models.
    String modelId = "YOUR_MODEL_ID";
    String versionId = "YOUR_VERSION_ID";
    param.set(
        "name", String.format("projects/%s/models/%s/versions/%s", projectId, modelId, versionId));

    GenericUrl url =
        new GenericUrl(UriTemplate.expand(api.getBaseUrl() + method.getPath(), param, true));
    System.out.println(url);

    String contentType = "application/json";
    File requestBodyFile = new File("input.txt");
    HttpContent content = new FileContent(contentType, requestBodyFile);
    System.out.println(content.getLength());

    GoogleCredential credential = GoogleCredential.getApplicationDefault();
    HttpRequestFactory requestFactory = httpTransport.createRequestFactory(credential);
    HttpRequest request = requestFactory.buildRequest(method.getHttpMethod(), url, content);

    String response = request.execute().parseAsString();
    System.out.println(response);
  }
}

線上預測疑難排解

線上預測的常見錯誤如下:

  • 記憶體不足錯誤
  • 輸入資料的格式不正確

請試著縮減模型大小,再將模型部署至 AI Platform 以進行預測。

詳情請參閱線上預測疑難排解一文。

後續步驟

本頁內容對您是否有任何幫助?請提供意見:

傳送您對下列選項的寶貴意見...

這個網頁
TensorFlow 適用的 AI Platform