オンライン予測の取得

AI Platform のオンライン予測は、可能な限り少ないレイテンシで、ホスト対象モデルによってデータを実行するように最適化されたサービスです。小さなデータのバッチをサービスに送信すると、レスポンスで予測が返されます。

オンライン予測 vs バッチ予測、または予測コンセプトの概要をご覧ください。

始める前に

予測をリクエストするには、まず次のことを行う必要があります。

  • 予測のグラフを最適化して、SavedModel のファイルサイズが、AI Platform のデフォルトの上限である 250 MB を下回るようにします。

  • 入力データがオンライン予測用の正しい形式であることを確認します。

リージョン

AI Platform のオンライン予測は、現在次のリージョンで利用できます。

  • us-central1
  • us-east1
  • us-east4
  • asia-northeast1
  • europe-west1

AI Platform のトレーニングと予測サービスで利用可能なリージョンの詳細については、リージョンのガイドをご覧ください。

モデルとバージョンの作成

モデルとバージョンのリソースを作成するときに、オンライン予測の実行方法について重要な決定を行います。

作成するリソース リソース作成時に決定すべきこと
モデル 予測を実行するリージョン
モデル オンライン予測のロギングを有効にする
バージョン 使用するランタイム バージョン
バージョン 使用する Python のバージョン
バージョン オンライン予測で使用するマシンタイプ

モデルまたはバージョンを作成した後は、上記の設定を更新できません。これらの設定を変更する必要がある場合は、新しい設定を使用して新しいモデルまたはバージョン リソースを作成し、モデルを再デプロイしてください。

オンライン予測に使用できるマシンタイプ

現在、オンライン予測ではシングルコア CPU とクアッドコア CPU をサポートしています。他のハードウェアのアルファ版プログラムへの参加を検討されている場合は、AI Platform のフィードバックまでお問い合わせください。

名前 コア数 RAM(GB)
mls1-c1-m2(デフォルト) 1 2
mls1-c4-m2(ベータ版) 4 2

オンライン予測でクアッドコア CPU を使用する場合は、AI Platform モデルバージョンを作成するときにマシンタイプ mls1-c4-m2 を指定します。

gcloud

クアッドコア CPU を指定するには、gcloud beta コンポーネントを使用して、オプションのフラグ --machine-type "mls1-c4-m2" をコマンドに追加します。次に例を示します。

  gcloud components install beta

  gcloud beta ai-platform versions create "[YOUR_VERSION_NAME]" \
      --model "[YOUR_MODEL_NAME]" \
      --origin "[YOUR_GCS_PATH_TO_MODEL_DIRECTORY]" \
      --runtime-version "1.14" \
      --python-version "3.5" \
      --machine-type "mls1-c4-m2"

Python

この例では、Python 用 Google API クライアント ライブラリを使用します。モデル バージョンの作成方法について詳しくは、こちらをご覧ください。

クアッドコア CPU を指定するには、バージョン作成リクエストに渡す requestDict に、オプションのエントリ 'machineType': 'mls1-c4-m2' を追加します。次に例を示します。

   requestDict = {'name': '[YOUR_VERSION_NAME]',
      'description': '[YOUR_VERSION_DESCRIPTION]',
      'deploymentUri': '[YOUR_GCS_PATH_TO_MODEL_DIRECTORY]',
      'runtimeVersion': '1.14',
      'pythonVersion': '3.5',
      'machineType': 'mls1-c4-m2'}

これらのマシンタイプの料金については、料金ページをご覧ください。

オンライン予測リクエストのログのリクエスト

ログには費用が発生するため、AI Platform 予測サービスのデフォルトではリクエストに関するログ情報を提供しません。1 秒あたりクエリ数(QPS)が非常に多いオンライン予測では、かなりの数のログが生成される可能性があり、その場合は Stackdriver の料金ポリシーが適用されます。

オンライン予測のロギングを有効にするには、モデルリソースの作成時にログを生成するようにモデルを構成できます。ロギングには次の 2 つのタイプがあり、それぞれ独立して有効にすることができます。

  • アクセス ロギング。各リクエストのタイムスタンプやレイテンシなどの情報が記録されます。

  • ストリーム ロギング。予測ノードからの stderr ストリームと stdout ストリームが記録され、デバッグに役立ちます。このタイプのロギングはベータ版です。

gcloud

アクセス ロギングを有効にするには、gcloud ai-platform models create コマンドを使ってモデルを作成する際に --enable-logging フラグを指定します。次に例を示します。

gcloud ai-platform models create model_name \
  --regions us-central1 \
  --enable-logging

ストリーム ロギング(ベータ版)を有効にするには、gcloud beta コンポーネントを使用し、--enable-console-logging フラグを含めます。次に例を示します。

gcloud components install beta

gcloud beta ai-platform models create model_name \
  --regions us-central1 \
  --enable-console-logging

REST API

アクセス ロギングを有効にするには、projects.models.create でモデルを作成する際に、モデルリソース内の onlinePredictionLoggingTrue に設定します。

ストリーム ロギング(ベータ版)を有効にするには、モデルリソース内の onlinePredictionConsoleLogging フィールドを True に設定します。

オンライン予測の入力フォーマット

JSON 文字列としてインスタンスをフォーマットする

オンライン予測の基本形式は、インスタンス データのテンソルのリストです。これらは、トレーニング アプリケーションで入力を構成した方法に応じて、値のプレーンリストまたは JSON オブジェクトのメンバーのいずれかになります。

次の例は、入力テンソルとインスタンス キーを示しています。

{"values": [1, 2, 3, 4], "key": 1}

次の規則に従う限り、JSON 文字列の構成は複雑になってもかまいません。

  • インスタンス データの最上位レベルは、Key-Value ペアの辞書である JSON オブジェクトでなければならない。

  • インスタンス オブジェクト内の個々の値は、文字列、数字、またはリスト。JSON オブジェクトを埋め込むことはできない。

  • リストには、同じ型(他のリストも含む)のアイテムのみが含まれている必要がある。文字列と数値を混在させることはできない。

オンライン予測の入力インスタンスを、projects.predict 呼び出しのメッセージ本文として渡します。詳しくは、リクエスト本文のフォーマット要件をご覧ください。

gcloud

  1. 入力ファイルがテキスト ファイルであり、各行に 1 つのインスタンスが JSON オブジェクトとして記述されていることを確認します。

    {"values": [1, 2, 3, 4], "key": 1}
    {"values": [5, 6, 7, 8], "key": 2}
    

REST API

  1. 各インスタンスをリスト内のアイテムにし、リストメンバーに instances という名前を付けます。

    {"instances": [
      {"values": [1, 2, 3, 4], "key": 1},
      {"values": [5, 6, 7, 8], "key": 2}
    ]}
    

予測入力でのバイナリデータ

バイナリデータは、JSON がサポートする UTF-8 エンコード文字列としてフォーマットできません。入力にバイナリデータがある場合は、base64 エンコーディングで表す必要があります。次の特殊なフォーマットが必要です。

  • エンコードされた文字列は、b64 という 1 つのキーを持つ JSON オブジェクトとしてフォーマットする必要があります。次の Python 2.7 の例では、base64 ライブラリを使用して生の JPEG データのバッファをエンコードし、インスタンスを作成しています。

    {"image_bytes": {"b64": base64.b64encode(jpeg_data)}}
    

    Python 3.5 では、Base64 エンコードによりバイト シーケンスが出力されます。この出力を文字列に変換して JSON にシリアル化できるようにします。

    {'image_bytes': {'b64': base64.b64encode(jpeg_data).decode()}}
    
  • TensorFlow モデルのコードでは、入力 / 出力テンソルのエイリアスの名前を「_bytes」で終わらせる必要があります。

予測のリクエスト

オンライン予測をリクエストするには、予測リクエストで入力データ インスタンスを JSON 文字列として送信します。リクエストとレスポンスの本文のフォーマットについては、予測リクエストの詳細をご覧ください。

モデルのバージョンを指定しない場合、モデルのデフォルトのバージョンが予測リクエストで使用されます。

gcloud

  1. 特定のモデル バージョンを指定する場合には、バージョン値などのパラメータを格納する環境変数を作成します。

    MODEL_NAME="[YOUR-MODEL-NAME]"
    INPUT_DATA_FILE="instances.json"
    VERSION_NAME="[YOUR-VERSION-NAME]"
    
  2. gcloud ai-platform predict を使用して、デプロイ済みモデルにインスタンスを送信します。なお、--version は任意です。

    gcloud ai-platform predict --model $MODEL_NAME  \
                       --version $VERSION_NAME \
                       --json-instances $INPUT_DATA_FILE
    
  3. gcloud ツールはレスポンスを解析し、人が読める形式で予測をターミナルに出力します。predict コマンドで --format フラグを使用すると、JSON や CSV など別の出力形式を指定できます。使用可能な出力形式をご覧ください。

Python

Python 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で作成しなくても、AI Platform のトレーニング API と Prediction API を呼び出すことができます。次のサンプルコードを実行する前に、認証を設定する必要があります。

def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the AI Platform Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    # Create the AI Platform service object.
    # To authenticate set the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    service = googleapiclient.discovery.build('ml', 'v1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']

Java

Java 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で作成しなくても、AI Platform のトレーニング API と予測 API を呼び出すことができます。次のサンプルコードを実行する前に、認証を設定する必要があります。

/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.FileContent;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.UriTemplate;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.discovery.Discovery;
import com.google.api.services.discovery.model.JsonSchema;
import com.google.api.services.discovery.model.RestDescription;
import com.google.api.services.discovery.model.RestMethod;
import java.io.File;

/*
 * Sample code for sending an online prediction request to Cloud Machine Learning Engine.
 */

public class OnlinePredictionSample {
  public static void main(String[] args) throws Exception {
    HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
    JsonFactory jsonFactory = JacksonFactory.getDefaultInstance();
    Discovery discovery = new Discovery.Builder(httpTransport, jsonFactory, null).build();

    RestDescription api = discovery.apis().getRest("ml", "v1").execute();
    RestMethod method = api.getResources().get("projects").getMethods().get("predict");

    JsonSchema param = new JsonSchema();
    String projectId = "YOUR_PROJECT_ID";
    // You should have already deployed a model and a version.
    // For reference, see https://cloud.google.com/ml-engine/docs/deploying-models.
    String modelId = "YOUR_MODEL_ID";
    String versionId = "YOUR_VERSION_ID";
    param.set(
        "name", String.format("projects/%s/models/%s/versions/%s", projectId, modelId, versionId));

    GenericUrl url =
        new GenericUrl(UriTemplate.expand(api.getBaseUrl() + method.getPath(), param, true));
    System.out.println(url);

    String contentType = "application/json";
    File requestBodyFile = new File("input.txt");
    HttpContent content = new FileContent(contentType, requestBodyFile);
    System.out.println(content.getLength());

    GoogleCredential credential = GoogleCredential.getApplicationDefault();
    HttpRequestFactory requestFactory = httpTransport.createRequestFactory(credential);
    HttpRequest request = requestFactory.buildRequest(method.getHttpMethod(), url, content);

    String response = request.execute().parseAsString();
    System.out.println(response);
  }
}

オンライン予測のトラブルシューティング

オンライン予測の一般的なエラーとしては、次のものがあります。

  • メモリ不足エラー
  • 入力データのフォーマットが正しくない

予測モデルを AI Platform にデプロイする前に、モデルのサイズを縮小してください

詳細については、オンライン予測のトラブルシューティングをご覧ください。

次のステップ

このページは役立ちましたか?評価をお願いいたします。

フィードバックを送信...

TensorFlow 用 AI Platform