取得線上預測

AI Platform 線上預測這項服務經過最佳化處理,會儘量以最短的延遲時間透過託管模型執行您的資料。當您將小批次的資料傳送給這項服務後,服務會在回應中傳回您的預測結果。

瞭解線上與批次預測的差異,或是瀏覽預測概念總覽

事前準備

如要提出預測要求,您必須先執行以下作業:

地區

目前 AI Platform 線上預測服務可在下列地區使用:

  • us-central1
  • us-east1
  • us-east4
  • asia-northeast1
  • europe-west1

用於線上預測 (Beta 版) 的 Compute Engine (N1) 機器類型僅於 us-central1 提供。

如要完全瞭解 AI Platform 訓練和預測服務的可用地區,請參閱地區指南

建立模型和版本

建立模型和版本資源時,您會做出幾項與如何執行線上預測的相關的重大決策,這些決策如下:

建立的資源 在建立資源時指定的決策
模型 執行預測的地區
模型 啟用線上預測記錄
版本 要使用的執行階段版本
版本 要使用的 Python 版本
版本 用於線上預測的機器類型

一開始建立模型或版本後,您就無法更新上方所列的設定。如果您需要變更這些設定,請使用新設定來建立新的模型或版本資源,並重新部署您的模型。

適用於線上預測的機器類型

建立版本時,您可以選擇 AI Platform Prediction 要用於線上預測節點的虛擬機器類型。進一步瞭解機器類型。

線上預測要求的要求記錄

根據預設,AI Platform Prediction 服務不會提供要求的相關記錄資訊,因為記錄會產生費用。以較高的每秒查詢次數 (QPS) 執行的線上預測服務會產生大量記錄,這些記錄受到 Stackdriver 定價政策BigQuery 定價政策規範。

如要啟用線上預測記錄,您必須在建立模型資源建立模型版本資源時加以設定,要在建立何者時設定視要啟用的記錄類型而定。記錄有三種類型,您可以單獨啟用每一種。記錄類型如下:

  • 存取記錄:將每個要求的資訊 (時間戳記與延遲時間等) 記錄到 Stackdriver Logging。

    您可以在建立模型資源時啟用存取記錄。

  • 串流記錄:將預測節點的 stderrstdout 串流記錄到 Stackdriver Logging,可用於偵錯。這類記錄仍在 Beta 版階段,且不受 Compute Engine (N1) 機器類型支援。

    您可以在建立模型資源時啟用串流記錄。

  • 要求/回應記錄:對線上預測要求和回應進行取樣,並記錄到 BigQuery 資料表。這類記錄仍在 Beta 版階段。

    您可以在建立模型版本資源時啟用要求/回應記錄。

gcloud

如要啟用存取記錄,請在透過 gcloud ai-platform models create 指令建立模型時納入 --enable-logging 旗標。例如:

gcloud ai-platform models create model_name \
  --regions us-central1 \
  --enable-logging

如要啟用串流記錄 (Beta 版),請使用 gcloud beta 元件並納入 --enable-console-logging 旗標。例如:

gcloud components install beta

gcloud beta ai-platform models create model_name \
  --regions us-central1 \
  --enable-console-logging

您目前還無法使用 gcloud 工具啟用 要求/回應記錄 (Beta 版)。只有在向 REST API 傳送 projects.models.versions.create 要求時,才能啟用此類記錄。

REST API

如要啟用存取記錄,請在透過 projects.models.create 建立模型時,將模型資源中的 onlinePredictionLogging 設為 True

如要啟用串流記錄 (Beta 版),請將模型資源中的 onlinePredictionConsoleLogging 欄位設為 True

要求/回應記錄

要求/回應記錄與其他類型的記錄不同,您無法在建立模型時啟用要求/回應記錄,但您可以在建立版本 (projects.models.versions.create) 時啟用這項功能。

如要啟用要求/回應記錄,請在版本資源的 requestLoggingConfig 欄位中填入以下項目:

  • samplingPercentage:介於 0 到 1 之間的數字,用於定義要記錄之要求的比例。例如,將值設為 1 可記錄所有要求,設為 0.1 則可記錄 10% 的要求。
  • bigqueryTableName:您要在其中記錄要求和回應的 BigQuery 資料表完整名稱 (project_id.dataset_name.table_name)。具有下列結構定義的資料表必須已存在。

    欄位名稱類型模式
    modelSTRINGREQUIRED
    model_versionSTRINGREQUIRED
    timeTIMESTAMPREQUIRED
    raw_dataSTRINGREQUIRED
    raw_predictionSTRINGNULLABLE
    groundtruthSTRINGNULLABLE

    瞭解如何建立 BigQuery 資料表

使用 What-If Tool 檢查模型

您可以在筆記本環境中使用 What-If Tool (WIT),透過互動式資訊主頁來檢查 AI Platform 模型。What-If Tool 整合了 TensorBoard、Jupyter 筆記本、Colab 筆記本,以及 JupyterHub。此工具也會預先安裝在 AI Platform Notebooks TensorFlow 執行個體上。

瞭解如何將 What-If Tool 與 AI Platform 搭配使用

讓輸入內容採用線上預測適用的格式

讓樣本採用 JSON 字串格式

線上預測的基本格式是一份資料樣本清單。視您在訓練應用程式中設定輸入內容的方式而定,樣本可以是簡單的值清單或 JSON 物件的內含元素。TensorFlow 模型和自訂預測處理常式可以接受較為複雜的輸入內容,而大部分 scikit-learn 和 XGBoost 模型採用的輸入內容格式都是數字清單。

這個範例顯示了 TensorFlow 模型的輸入張量和樣本鍵:

{"values": [1, 2, 3, 4], "key": 1}

如果 JSON 字串遵循下列規則,其組成內容可以非常複雜:

  • 樣本資料的頂層必須是 JSON 物件,也就是鍵/值組合的字典。

  • 樣本物件的個別值可為字串、數字或清單。您無法嵌入 JSON 物件。

  • 清單僅能包含相同類型的項目 (包括其他清單)。不能混合使用字串值和數值。

您要將線上預測的輸入樣本做為 projects.predict 呼叫的訊息主體傳送。進一步瞭解要求主體的格式需求

gcloud

  1. 確定您的輸入檔案是以換行符號分隔的 JSON 檔案,當中的每個樣本都是 JSON 物件,且每行一個樣本。

    {"values": [1, 2, 3, 4], "key": 1}
    {"values": [5, 6, 7, 8], "key": 2}
    

REST API

  1. 將每個樣本設為清單中的一個項目,並將清單成員命名為 instances

    {"instances": [
      {"values": [1, 2, 3, 4], "key": 1},
      {"values": [5, 6, 7, 8], "key": 2}
    ]}
    

預測輸入的二進位資料

二進位資料無法採用 JSON 支援的 UTF-8 編碼字串格式。如果您的輸入中有二進位資料,則必須使用 base64 編碼來表示。下列是必要的特殊格式設定:

  • 編碼字串的格式必須是 JSON 物件,且具有 b64 單一索引鍵的。下列 Python 2.7 範例使用 base64 程式庫對原始 JPEG 資料的緩衝區進行編碼,以建立樣本:

    {"image_bytes": {"b64": base64.b64encode(jpeg_data)}}
    

    在 Python 3.5 中,base64 編碼會輸出一個位元組序列。您必須將這個序列轉換成字串,讓它可透過 JSON 序列化:

    {'image_bytes': {'b64': base64.b64encode(jpeg_data).decode()}}
    
  • 在 TensorFlow 模型程式碼中,您必須為二進位輸入和輸出張量提供結尾為「_bytes」的別名。

要求預測

predict 要求中將輸入資料樣本以 JSON 字串的形式傳送,藉此提出線上預測要求。如要瞭解如何設定要求和回應主體的格式,請參閱預測要求細節一文。

如果您未指定模型版本,預測要求會使用模型的預設版本

gcloud

  1. 建立環境變數以放置參數,包括版本值 (如果您決定指定特定模型版本):

    MODEL_NAME="[YOUR-MODEL-NAME]"
    INPUT_DATA_FILE="instances.json"
    VERSION_NAME="[YOUR-VERSION-NAME]"
    
  2. 使用 gcloud ai-platform predict 將樣本傳送至已部署的模型。請注意,--version 為選用。

    gcloud ai-platform predict --model $MODEL_NAME  \
                       --version $VERSION_NAME \
                       --json-instances $INPUT_DATA_FILE
    
  3. gcloud 工具會剖析回應,並以使用者可理解的格式,將預測結果輸出到您的終端機。您可以在 predict 指令中使用 --format 標記來指定其他輸出格式,例如 JSON 或 CSV。查看可用的輸出格式

Python

您可以使用 Python 適用的 Google API 用戶端程式庫呼叫 AI Platform Training and Prediction API,不需要手動建構 HTTP 要求。您必須先設定驗證,才能執行下列程式碼範例。

def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the AI Platform Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    # Create the AI Platform service object.
    # To authenticate set the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    service = googleapiclient.discovery.build('ml', 'v1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']

Java

您可以使用 Java 適用的 Google API 用戶端程式庫呼叫 AI Platform Training and Prediction API,不需要手動建構 HTTP 要求。您必須先設定驗證,才能執行下列程式碼範例。

/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.FileContent;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.UriTemplate;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.discovery.Discovery;
import com.google.api.services.discovery.model.JsonSchema;
import com.google.api.services.discovery.model.RestDescription;
import com.google.api.services.discovery.model.RestMethod;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

/*
 * Sample code for sending an online prediction request to Cloud Machine Learning Engine.
 */

public class OnlinePredictionSample {
  public static void main(String[] args) throws Exception {
    HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
    JsonFactory jsonFactory = JacksonFactory.getDefaultInstance();
    Discovery discovery = new Discovery.Builder(httpTransport, jsonFactory, null).build();

    RestDescription api = discovery.apis().getRest("ml", "v1").execute();
    RestMethod method = api.getResources().get("projects").getMethods().get("predict");

    JsonSchema param = new JsonSchema();
    String projectId = "YOUR_PROJECT_ID";
    // You should have already deployed a model and a version.
    // For reference, see https://cloud.google.com/ml-engine/docs/deploying-models.
    String modelId = "YOUR_MODEL_ID";
    String versionId = "YOUR_VERSION_ID";
    param.set(
        "name", String.format("projects/%s/models/%s/versions/%s", projectId, modelId, versionId));

    GenericUrl url =
        new GenericUrl(UriTemplate.expand(api.getBaseUrl() + method.getPath(), param, true));
    System.out.println(url);

    String contentType = "application/json";
    File requestBodyFile = new File("input.txt");
    HttpContent content = new FileContent(contentType, requestBodyFile);
    System.out.println(content.getLength());

    List<String> scopes = new ArrayList<>();
    scopes.add("https://www.googleapis.com/auth/cloud-platform");

    GoogleCredential credential = GoogleCredential.getApplicationDefault().createScoped(scopes);
    HttpRequestFactory requestFactory = httpTransport.createRequestFactory(credential);
    HttpRequest request = requestFactory.buildRequest(method.getHttpMethod(), url, content);

    String response = request.execute().parseAsString();
    System.out.println(response);
  }
}

線上預測疑難排解

線上預測的常見錯誤如下:

  • 記憶體不足錯誤
  • 輸入資料的格式不正確

請試著縮減模型大小,再將模型部署至 AI Platform 進行預測。

詳情請參閱排解線上預測問題的相關說明。

後續步驟

本頁內容對您是否有任何幫助?請提供意見:

傳送您對下列選項的寶貴意見...

這個網頁
Google Cloud Machine Learning 說明文件