オンライン予測の取得

AI Platform Prediction のオンライン予測は、可能な限り少ないレイテンシで、ホスト対象モデルによってデータを実行するように最適化されたサービスです。バッチサイズの小さいデータをサービスに送信すると、レスポンスで予測が返されます。

オンライン予測 vs バッチ予測、または予測の概要をご覧ください。

始める前に

予測をリクエストするには、まず次のことを行う必要があります。

リージョン

オンライン予測は、特定のリージョンで利用可能です。さらに、各リージョンでは異なるマシンタイプが利用可能です。各リージョンでのオンライン予測の可用性については、リージョンのガイドをご覧ください。

モデルとバージョンの作成

モデルとバージョンのリソースを作成するときに、オンライン予測の実行方法について重要な決定を行います。

作成するリソース リソース作成時に決定すべきこと
モデル 予測を実行するリージョン
モデル オンライン予測のロギングを有効にする
バージョン 使用するランタイム バージョン
バージョン 使用する Python のバージョン
バージョン オンライン予測で使用するマシンタイプ

モデルまたはバージョンを作成した後は、上記の設定を更新できません。これらの設定を変更する必要がある場合は、新しい設定を使用して新しいモデルまたはバージョン リソースを作成し、モデルを再デプロイしてください。

オンライン予測に使用できるマシンタイプ

バージョンの作成時に、AI Platform Prediction に使用させる、オンライン予測ノードの仮想マシンのタイプを選択できます。詳細については、マシンタイプをご覧ください。

オンライン予測リクエストのログのリクエスト

ログには費用が発生するため、AI Platform Prediction 予測サービスのデフォルトではリクエストに関するログ情報を提供しません。秒間クエリ数(QPS)が非常に多いオンライン予測の場合、かなりの数のログが生成される可能性があります。これらのログには Cloud Logging の料金または BigQuery の料金が適用されます。

オンライン予測ロギングを有効にする場合は、有効にするロギングのタイプによって、モデルリソースの作成時またはモデル バージョンのリソースの作成時に構成する必要があります。ロギングには次の 3 つのタイプがあり、それぞれ独立して有効にすることができます。

  • アクセス ロギング。各リクエストのタイムスタンプやレイテンシなどの情報が Cloud Logging に記録されます。

    アクセス ロギングは、モデルリソースの作成時に有効にできます。

  • コンソール ロギング。予測ノードからの stderr ストリームと stdout ストリームが Cloud Logging に記録されます。これは、デバッグに役立つタイプのロギングです。Compute Engine(N1)マシンタイプの場合、このタイプのロギングはプレビュー版です。従来の(MLS1)マシンタイプでは一般提供されています。

    コンソール ロギングは、モデルリソースの作成時に有効にできます。

  • リクエスト / レスポンス ロギング。オンライン予測リクエストとレスポンスのサンプルが BigQuery テーブルに記録されます。このタイプのロギングはベータ版です。

    リクエスト / レスポンス ロギングを有効にするには、モデル バージョン リソースを作成して、バージョンを更新します

gcloud

アクセス ロギングを有効にするには、gcloud ai-platform models create コマンドでモデルを作成する際に、--enable-logging フラグを含めます。例:

gcloud ai-platform models create MODEL_NAME \
  --region=us-central1 \
  --enable-logging

コンソール ロギング(プレビュー版)を有効にするには、gcloud beta コンポーネントを使用し、--enable-console-logging フラグを含めます。例:

gcloud components install beta

gcloud beta ai-platform models create MODEL_NAME \
  --region=us-central1 \
  --enable-console-logging

現在、gcloud CLI を使用してリクエスト / レスポンス ロギング(ベータ版)を有効にすることはできません。このタイプのロギングを有効にできるのは、projects.models.versions.patch リクエストを REST API に送信する場合のみです。

REST API

アクセス ロギングを有効にするには、projects.models.create でモデルを作成する際に、Model リソースの onlinePredictionLogging の値を True に設定します。

コンソール ロギング(ベータ版)を有効にするには、Model リソースの onlinePredictionConsoleLogging フィールドの値を True に設定します。

リクエスト / レスポンス ロギング

他のタイプのロギングとは異なり、モデルの作成時にリクエスト / レスポンス ロギングを有効にすることはできません。代わりに、既存のモデル バージョンに対して projects.models.versions.patch メソッドを使用すると有効にできます(まず、Google Cloud Console、gcloud CLI、または REST API を使用して、モデル バージョンを作成する必要があります)。

リクエスト / レスポンス ロギングを有効にするには、Version リソースの requestLoggingConfig フィールドに次のエントリを入力します。

  • samplingPercentage: ログに記録するリクエストの割合を定義する 0~1 の数値。たとえば、すべてのリクエストをログに記録する場合は、この値を 1 に設定します。リクエストの 10% をログに記録する場合は 0.1 に設定します。
  • bigqueryTableName: リクエストとレスポンスを記録する BigQuery テーブルの完全修飾名(PROJECT_ID.DATASET_NAME.TABLE_NAME)。このテーブルはすでに存在し、次のスキーマが適用されている必要があります。

    フィールド名モード
    modelSTRINGREQUIRED
    model_versionSTRINGREQUIRED
    timeTIMESTAMPREQUIRED
    raw_dataSTRINGREQUIRED
    raw_predictionSTRINGNULLABLE
    groundtruthSTRINGNULLABLE

    BigQuery テーブルの作成方法をご覧ください。

What-If ツールを使用してモデルを検査する

ノートブック環境内で What-If Tool(WIT)を使用すると、インタラクティブ ダッシュボードを介して AI Platform Prediction モデルを検査できます。What-If ツールは、TensorBoard、Jupyter ノートブック、Colab ノートブック、JupyterHub に統合されています。また、Vertex AI Workbench ユーザー管理ノートブックの TensorFlow インスタンスにもプリインストールされています。

AI Platform での What-If ツールの使用方法をご覧ください。

オンライン予測の入力フォーマット

JSON 文字列としてインスタンスをフォーマットする

オンライン予測の基本形式は、データ インスタンスをリスト化したものです。これらは、トレーニング アプリケーションで入力を構成した方法に応じて、値のプレーンリストまたは JSON オブジェクトのメンバーのいずれかになります。scikit-learn と XGBoost のほとんどのモデルは入力として数値のリストを想定していますが、TensorFlow モデルとカスタム予測ルーチンはより複雑な入力にも対応しています。

次の例は、TensorFlow モデルの入力テンソルとインスタンス キーを示しています。

{"values": [1, 2, 3, 4], "key": 1}

次の規則に従う限り、JSON 文字列の構成は複雑になってもかまいません。

  • インスタンス データの最上位レベルは、Key-Value ペアの辞書である JSON オブジェクトでなければなりません。

  • インスタンス オブジェクト内の個々の値には、文字列、数字、またはリストを使用できます。JSON オブジェクトを埋め込むことはできません。

  • リストには、同じ型(他のリストも含む)のアイテムのみが含まれている必要があります。文字列と数値を混在させることはできません。

オンライン予測の入力インスタンスを、projects.predict 呼び出しのメッセージ本文として渡します。詳細については、リクエスト本文のフォーマット要件をご覧ください。

gcloud

予測リクエストの送信方法に応じて、2 つの異なる方法で入力をフォーマットできます。gcloud ai-platform predict コマンドの --json-request フラグを使用することをおすすめします。または、改行で区切られた JSON データで --json-instances フラグを使用することもできます。

--json-request の場合

各インスタンスを JSON 配列のアイテムにし、その配列を JSON ファイルの instances フィールドとして指定します。例:

instances.json

{"instances": [
  {"values": [1, 2, 3, 4], "key": 1},
  {"values": [5, 6, 7, 8], "key": 2}
]}

--json-instances の場合

入力ファイルが改行区切りの JSON ファイルであり、各行に 1 つのインスタンスが JSON オブジェクトとして記述されていることを確認します。例:

instances.jsonl

{"values": [1, 2, 3, 4], "key": 1}
{"values": [5, 6, 7, 8], "key": 2}

REST API

各インスタンスを JSON 配列のアイテムにし、その配列を JSON オブジェクトの instances フィールドとして指定します。例:

{"instances": [
  {"values": [1, 2, 3, 4], "key": 1},
  {"values": [5, 6, 7, 8], "key": 2}
]}

予測入力でのバイナリデータ

バイナリデータは、JSON がサポートする UTF-8 エンコード文字列としてフォーマットできません。入力にバイナリデータがある場合は、base64 エンコーディングで表す必要があります。次の特殊なフォーマットが必要です。

  • エンコードされた文字列は、b64 という 1 つのキーを持つ JSON オブジェクトとしてフォーマットする必要があります。次の Python 2.7 の例では、base64 ライブラリを使用して生の JPEG データのバッファをエンコードし、インスタンスを作成しています。

    {"image_bytes": {"b64": base64.b64encode(jpeg_data)}}
    

    Python 3 では、Base64 エンコードによりバイト シーケンスが出力されます。この出力を文字列に変換して JSON にシリアル化できるようにします。

    {'image_bytes': {'b64': base64.b64encode(jpeg_data).decode()}}
    
  • TensorFlow モデルのコードでは、入力 / 出力テンソルのエイリアスの名前を「_bytes」で終わらせる必要があります。

予測のリクエスト

オンライン予測をリクエストするには、予測リクエストで入力データ インスタンスを JSON 文字列として送信します。リクエストとレスポンスの本文のフォーマットについては、予測リクエストの詳細をご覧ください。

モデルのバージョンを指定しない場合、モデルのデフォルトのバージョンが予測リクエストで使用されます。

gcloud

  1. 特定のモデル バージョンを指定する場合には、バージョン値などのパラメータを格納する環境変数を作成します。

    MODEL_NAME="[YOUR-MODEL-NAME]"
    INPUT_DATA_FILE="instances.json"
    VERSION_NAME="[YOUR-VERSION-NAME]"
    
  2. gcloud ai-platform predict を使用して、デプロイ済みモデルにインスタンスを送信します。--version は省略可能です。

    gcloud ai-platform predict \
      --model=$MODEL_NAME \
      --version=$VERSION_NAME \
      --json-request=$INPUT_DATA_FILE \
      --region=REGION
    

    REGION は、モデルを作成したリージョン エンドポイントのリージョンに置き換えます。グローバル エンドポイントにモデルを作成した場合、--region フラグは省略します。

  3. gcloud ツールはレスポンスを解析し、人が読める形式で予測をターミナルに出力します。predict コマンドで --format フラグを使用すると、JSON や CSV など別の出力形式を指定できます。使用可能な出力形式をご覧ください。

Python

Python 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で作成しなくても、AI Platform Training and Prediction API を呼び出すことができます。次のコードサンプルを実行する前に、認証を設定する必要があります。

# Create the AI Platform service object.
# To authenticate set the environment variable
# GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
service = googleapiclient.discovery.build("ml", "v1")

def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the AI Platform Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    name = f"projects/{project}/models/{model}"

    if version is not None:
        name += f"/versions/{version}"

    response = (
        service.projects().predict(name=name, body={"instances": instances}).execute()
    )

    if "error" in response:
        raise RuntimeError(response["error"])

    return response["predictions"]

Java

Java 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で作成しなくても、AI Platform のトレーニング API と予測 API を呼び出すことができます。次のサンプルコードを実行する前に、認証を設定する必要があります。

/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.FileContent;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.UriTemplate;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.gson.GsonFactory;
import com.google.api.services.discovery.Discovery;
import com.google.api.services.discovery.model.JsonSchema;
import com.google.api.services.discovery.model.RestDescription;
import com.google.api.services.discovery.model.RestMethod;
import com.google.auth.http.HttpCredentialsAdapter;
import com.google.auth.oauth2.GoogleCredentials;
import java.io.File;
import java.util.ArrayList;
import java.util.List;

/*
 * Sample code for sending an online prediction request to Cloud Machine Learning Engine.
 */

public class OnlinePredictionSample {
  public static void main(String[] args) throws Exception {
    HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
    JsonFactory jsonFactory = GsonFactory.getDefaultInstance();
    Discovery discovery = new Discovery.Builder(httpTransport, jsonFactory, null).build();

    RestDescription api = discovery.apis().getRest("ml", "v1").execute();
    RestMethod method = api.getResources().get("projects").getMethods().get("predict");

    JsonSchema param = new JsonSchema();
    String projectId = "YOUR_PROJECT_ID";
    // You should have already deployed a model and a version.
    // For reference, see https://cloud.google.com/ml-engine/docs/deploying-models.
    String modelId = "YOUR_MODEL_ID";
    String versionId = "YOUR_VERSION_ID";
    param.set(
        "name", String.format("projects/%s/models/%s/versions/%s", projectId, modelId, versionId));

    GenericUrl url =
        new GenericUrl(UriTemplate.expand(api.getBaseUrl() + method.getPath(), param, true));
    System.out.println(url);

    String contentType = "application/json";
    File requestBodyFile = new File("input.txt");
    HttpContent content = new FileContent(contentType, requestBodyFile);
    System.out.println(content.getLength());

    List<String> scopes = new ArrayList<>();
    scopes.add("https://www.googleapis.com/auth/cloud-platform");

    GoogleCredentials credential = GoogleCredentials.getApplicationDefault().createScoped(scopes);
    HttpRequestFactory requestFactory =
        httpTransport.createRequestFactory(new HttpCredentialsAdapter(credential));
    HttpRequest request = requestFactory.buildRequest(method.getHttpMethod(), url, content);

    String response = request.execute().parseAsString();
    System.out.println(response);
  }
}

オンライン予測のトラブルシューティング

オンライン予測の一般的なエラーとしては、次のものがあります。

  • メモリ不足エラー
  • 入力データのフォーマットが正しくない
  • 1 つのオンライン予測リクエストには 1.5 MB 以下のデータを含める必要があります。gcloud CLI を使用して作成されたリクエストでは、1 つのファイルにつき 100 個以下のインスタンスを処理できます。それより多くのインスタンスに対して同時に予測を取得するには、バッチ予測を使用します。

予測用のモデルを AI Platform Prediction にデプロイする前に、モデルのサイズを縮小してください。

詳細については、オンライン予測のトラブルシューティングをご覧ください。

次のステップ