オンライン予測の取得

Cloud Machine Learning Engine のオンライン予測は、可能な限り少ないレイテンシで、ホスト対象モデルによってデータを実行するように最適化されたサービスです。小さなデータのバッチをサービスに送信すると、レスポンスで予測が返されます。

オンライン予測 vs バッチ予測を参照するか、予測コンセプトの概要をご覧ください。

始める前に

予測をリクエストするには、まず次のことを行う必要があります。

  • 予測のためにグラフを改善して、SavedModel のファイルサイズが Cloud ML Engine のデフォルト制限の 250 MB 以下になるようにします。

  • 入力データがオンライン予測用の正しい形式であることを確認します。

リージョン

Cloud ML Engine のオンライン予測は、現在次のリージョンで利用できます。

  • us-central1
  • us-east1
  • us-east4
  • asia-northeast1
  • europe-west1

Cloud ML Engine のトレーニング サービスや予測サービスで利用可能なリージョンを完全に理解するには、リージョン ガイドをお読みください。

モデルとバージョンの作成

モデルとバージョンのリソースを作成するときに、オンライン予測の実行方法について重要な決定を行います。

作成するリソース リソース作成時に決定すべきこと
モデル 予測を実行するリージョン
モデル オンライン予測のロギングを有効にする
バージョン 使用するランタイム バージョン
バージョン 使用する Python のバージョン
バージョン オンライン予測で使用するマシンタイプ

モデルまたはバージョンを作成した後は、上記の設定を更新できません。これらの設定を変更する必要がある場合は、新しい設定を使用して新しいモデルまたはバージョン リソースを作成し、モデルを再デプロイしてください。

オンライン予測に使用できるマシンタイプ

現在、オンライン予測ではシングルコア CPU とクアッドコア CPU をサポートしています。他のハードウェアのアルファ版プログラムへの参加をご希望の方は、Cloud ML Engine フィードバックでお問い合わせください。

名前 コア数 RAM(GB)
mls1-c1-m2(デフォルト) 1 2
mls1-c4-m2(ベータ版) 4 2

オンライン予測でクアッドコア CPU を使用するには、Cloud ML Engine モデル バージョンを作成するときに、マシンタイプとして mls1-c4-m2 を指定します。

gcloud

クアッドコア CPU を指定するには、gcloud beta コンポーネントを使用して、オプションのフラグ --machine-type "mls1-c4-m2" をコマンドに追加します。次に例を示します。

  gcloud components install beta

  gcloud beta ml-engine versions create "[YOUR_VERSION_NAME]" \
      --model "[YOUR_MODEL_NAME]" \
      --origin "[YOUR_GCS_PATH_TO_MODEL_DIRECTORY]"\
      --runtime-version "1.13"
      --python-version "3.5"
      --machine-type "mls1-c4-m2"

Python

このサンプルでは、Python 用 Google API クライアント ライブラリを使用します。モデル バージョンの作成方法について詳しくは、こちらをご覧ください。

クアッドコア CPU を指定するには、バージョン作成リクエストに渡す requestDict に、オプションのエントリ 'machineType': 'mls1-c4-m2' を追加します。次に例を示します。

   requestDict = {'name': '[YOUR_VERSION_NAME]',
      'description': '[YOUR_VERSION_DESCRIPTION]',
      'deploymentUri': '[YOUR_GCS_PATH_TO_MODEL_DIRECTORY]',
      'runtimeVersion': '1.13',
      'pythonVersion': '3.5',
      'machineType': 'mls1-c4-m2'}

これらのマシンタイプの料金については、料金ページをご覧ください。

オンライン予測リクエストのログのリクエスト

ログは課金対象のため、Cloud ML Engine 予測サービスは、デフォルトでリクエストに関するログ情報を提供しません。1 秒あたりのクエリ数(QPS)が非常に多いオンライン予測の場合、相当数のログが生成される可能性があり、Stackdriver の料金ポリシーが適用されます。

オンライン予測のロギングを有効にするには、モデルリソースの作成時にログを生成するようにモデルを構成します。

gcloud

gcloud ml-engine models create コマンドでモデルを作成するときに、--enable-logging フラグを指定します。

Python

projects.models.create でモデルを作成するときに、モデルリソースで onlinePredictionLoggingTrue に設定します。

オンライン予測の入力フォーマット

JSON 文字列としてインスタンスをフォーマットする

オンライン予測の基本形式は、インスタンス データのテンソルのリストです。これらは、トレーニング アプリケーションで入力を構成した方法に応じて、値のプレーンリストまたは JSON オブジェクトのメンバーのいずれかになります。

次の例は、入力テンソルとインスタンス キーを示しています。

{"values": [1, 2, 3, 4], "key": 1}

次の規則に従う限り、JSON 文字列の構成は複雑になってもかまいません。

  • インスタンス データの最上位レベルは、Key-Value ペアの辞書である JSON オブジェクトでなければならない。

  • インスタンス オブジェクト内の個々の値は、文字列、数字、またはリスト。JSON オブジェクトを埋め込むことはできない。

  • リストには、同じ型(他のリストも含む)のアイテムのみが含まれている必要がある。文字列と数値を混在させることはできない。

オンライン予測の入力インスタンスは、projects.predict 呼び出しのメッセージ本文として渡します。

gcloud

  1. 入力ファイルがテキスト ファイルであり、各行に 1 つのインスタンスが JSON オブジェクトとして記述されていることを確認します。

    {"values": [1, 2, 3, 4], "key": 1}
    {"values": [5, 6, 7, 8], "key": 2}
    

REST API

  1. 各インスタンスをリスト内のアイテムにし、リストメンバーに instances という名前を付けます。

    {"instances": [{"values": [1, 2, 3, 4], "key": 1}]}
    

予測入力でのバイナリデータ

バイナリデータは、JSON がサポートする UTF-8 エンコード文字列としてフォーマットできません。入力にバイナリデータがある場合は、base64 エンコーディングで表す必要があります。次の特殊なフォーマットが必要です。

  • エンコードされた文字列は、b64 という 1 つのキーを持つ JSON オブジェクトとしてフォーマットする必要があります。次の Python 2.7 の例では、base64 ライブラリを使用して生の JPEG データのバッファをエンコードし、インスタンスを作成しています。

    {"image_bytes": {"b64": base64.b64encode(jpeg_data)}}
    

    Python 3.5 では、Base64 エンコードによりバイト シーケンスが出力されます。この出力を文字列に変換して JSON にシリアル化できるようにします。

    {'image_bytes': {'b64': base64.b64encode(jpeg_data).decode()}}
    
  • TensorFlow モデルのコードでは、入力 / 出力テンソルのエイリアスの名前を「_bytes」で終わらせる必要があります。

予測のリクエスト

オンライン予測をリクエストするには、予測リクエストで入力データ インスタンスを JSON 文字列として送信します。リクエストとレスポンスの本文のフォーマットについては、予測リクエストの詳細をご覧ください。

モデルのバージョンを指定しない場合、モデルのデフォルトのバージョンが予測リクエストで使用されます。

gcloud

  1. モデル バージョンを指定する場合は、バージョン値などのパラメータを保持する環境変数を作成します。

    MODEL_NAME="[YOUR-MODEL-NAME]"
    INPUT_DATA_FILE="instances.json"
    VERSION_NAME="[YOUR-VERSION-NAME]"
    
  2. gcloud ml-engine predict を使用して、インスタンスをデプロイ済みモデルに送信します。--version は任意です。

    gcloud ml-engine predict --model $MODEL_NAME  \
                       --version $VERSION_NAME \
                       --json-instances $INPUT_DATA_FILE
    
  3. gcloud ツールはレスポンスを解析し、人が読める形式で予測を端末に出力します。predict コマンドで --format フラグを使用すると、JSON や CSV など別の出力形式を指定できます。使用可能な出力形式をご覧ください。

Python

Python 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で構築することなく、Cloud Machine Learning Engine API を呼び出せます。次のサンプルコードを実行する前に、認証を設定する必要があります。

def predict_json(project, model, instances, version=None):
    """Send json data to a deployed model for prediction.

    Args:
        project (str): project where the Cloud ML Engine Model is deployed.
        model (str): model name.
        instances ([Mapping[str: Any]]): Keys should be the names of Tensors
            your deployed model expects as inputs. Values should be datatypes
            convertible to Tensors, or (potentially nested) lists of datatypes
            convertible to tensors.
        version: str, version of the model to target.
    Returns:
        Mapping[str: any]: dictionary of prediction results defined by the
            model.
    """
    # Create the ML Engine service object.
    # To authenticate set the environment variable
    # GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
    service = googleapiclient.discovery.build('ml', 'v1')
    name = 'projects/{}/models/{}'.format(project, model)

    if version is not None:
        name += '/versions/{}'.format(version)

    response = service.projects().predict(
        name=name,
        body={'instances': instances}
    ).execute()

    if 'error' in response:
        raise RuntimeError(response['error'])

    return response['predictions']

Java

Java 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で作成することなく、Cloud Machine Learning Engine API を呼び出せます。次のサンプルコードを実行する前に、認証を設定する必要があります。

/*
 * Copyright 2017 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport;
import com.google.api.client.http.FileContent;
import com.google.api.client.http.GenericUrl;
import com.google.api.client.http.HttpContent;
import com.google.api.client.http.HttpRequest;
import com.google.api.client.http.HttpRequestFactory;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.UriTemplate;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.discovery.Discovery;
import com.google.api.services.discovery.model.JsonSchema;
import com.google.api.services.discovery.model.RestDescription;
import com.google.api.services.discovery.model.RestMethod;
import java.io.File;

/*
 * Sample code for doing Cloud Machine Learning Engine online prediction in Java.
 */

public class OnlinePredictionSample {
  public static void main(String[] args) throws Exception {
    HttpTransport httpTransport = GoogleNetHttpTransport.newTrustedTransport();
    JsonFactory jsonFactory = JacksonFactory.getDefaultInstance();
    Discovery discovery = new Discovery.Builder(httpTransport, jsonFactory, null).build();

    RestDescription api = discovery.apis().getRest("ml", "v1").execute();
    RestMethod method = api.getResources().get("projects").getMethods().get("predict");

    JsonSchema param = new JsonSchema();
    String projectId = "YOUR_PROJECT_ID";
    // You should have already deployed a model and a version.
    // For reference, see https://cloud.google.com/ml-engine/docs/deploying-models.
    String modelId = "YOUR_MODEL_ID";
    String versionId = "YOUR_VERSION_ID";
    param.set(
        "name", String.format("projects/%s/models/%s/versions/%s", projectId, modelId, versionId));

    GenericUrl url =
        new GenericUrl(UriTemplate.expand(api.getBaseUrl() + method.getPath(), param, true));
    System.out.println(url);

    String contentType = "application/json";
    File requestBodyFile = new File("input.txt");
    HttpContent content = new FileContent(contentType, requestBodyFile);
    System.out.println(content.getLength());

    GoogleCredential credential = GoogleCredential.getApplicationDefault();
    HttpRequestFactory requestFactory = httpTransport.createRequestFactory(credential);
    HttpRequest request = requestFactory.buildRequest(method.getHttpMethod(), url, content);

    String response = request.execute().parseAsString();
    System.out.println(response);
  }
}

オンライン予測のトラブルシューティング

オンライン予測の一般的なエラーとしては、次のものがあります。

  • メモリ不足エラー
  • 入力データのフォーマットが正しくない

予測のために Cloud ML Engine にデプロイする前に、モデルサイズの縮小を試してください。

詳細については、オンライン予測のトラブルシューティングをご覧ください。

次のステップ

このページは役立ちましたか?評価をお願いいたします。

フィードバックを送信...

TensorFlow 用 Cloud ML Engine