予測を直ちに必要としない場合、または予測を取得する多数のインスタンスがある場合は、バッチ予測サービスを使用できます。このページでは、AI Platform Prediction のバッチ予測ジョブを開始する方法について説明します。AI Platform Prediction は、TensorFlow モデルからのバッチ予測の取得のみをサポートします。
オンライン予測 vs バッチ予測、または予測コンセプトの概要をご覧ください。
始める前に
予測をリクエストするには、まず次のことを行う必要があります。
プロジェクトからアクセス可能な Cloud Storage の場所にモデルリソースとバージョン リソースを作成するか、TensorFlow SavedModel を配置します。
- バッチ予測にバージョン リソースを使用する場合は、
mls1-c1-m2
マシンタイプのバージョンを作成する必要があります。
- バッチ予測にバージョン リソースを使用する場合は、
次のファイルの保存先として、プロジェクトからアクセス可能な Cloud Storage の場所を設定します。
入力データファイル。複数の場所を使用できます。ただし、プロジェクトにはそれらの場所から読み取るための権限が必要です。
出力ファイル。出力パスは 1 つしか指定できません。プロジェクトにはそこにデータを書き込むための権限が必要です。
入力ファイルがバッチ予測用の正しい形式であることを確認します。
バッチ予測ジョブの構成
バッチ予測ジョブを開始するには、いくつかの設定データを収集する必要があります。これは、API を直接呼び出すときに使用する PredictionInput オブジェクトに含まれるデータと同じです。
- データ形式
入力ファイルに使用する入力形式のタイプ。特定のジョブのすべての入力ファイルは、同じデータ形式を使用する必要があります。次のいずれかの値に設定します。
- JSON
入力ファイルはプレーン テキストで、各行にインスタンスがあります。これは、予測のコンセプト ページで説明されている形式です。
- TF_RECORD
入力ファイルは、TensorFlow TFRecords 形式を使用します。
- TF_RECORD_GZIP
入力ファイルは、GZIP で圧縮された TFRecords ファイルです。
- 入力パス
入力データファイルの URI。Google Cloud Storage のロケーション内の URI でなければなりません。次のように指定できます。
特定のファイルへのパス:
'gs://path/to/my/input/file.json'
。1 つのアスタリスク ワイルドカードを使用したディレクトリ パス。ワイルドカードで、そのディレクトリ内のすべてのファイルを表します。
'gs://path/to/my/input/*'
末尾にアスタリスク ワイルドカードを 1 つ付けた、部分ファイル名へのパス。ワイルドカードで、指定されたシーケンスで始まるすべてのファイルを表します。
'gs://path/to/my/input/file*'
複数の URI を組み合わせることができます。Python では、リストを作成します。Google Cloud CLI を使用する場合や、API を直接呼び出す場合は、複数の URI をカンマ区切りのリストで指定できます。URL の間にはスペースを入れません。これは
--input-paths
フラグの正しい形式です。--input-paths gs://a/directory/of/files/*,gs://a/single/specific/file.json,gs://a/file/template/data*
- 出力パス
予測サービスの結果の保存先として使用する、Cloud Storage の場所へのパス。プロジェクトには、この場所に書き込む権限が必要です。
- モデル名とバージョン名
予測を取得するモデルの名前、およびオプションでバージョン。バージョンを指定しない場合、モデルのデフォルトのバージョンが使用されます。バッチ予測では、バージョンは
mls1-c1-m2
マシンタイプを使用する必要があります。モデル URI(次のセクションで説明)を指定する場合は、このようなフィールドを省略します。
- モデル URI
使用する SavedModel の URI を指定すると、AI Platform Prediction にデプロイされていないモデルから予測を取得できます。SavedModel は Cloud Storage に保存されている必要があります。
要約すると、バッチ予測に使用するモデルを指定するには、3 つのオプションがあります。次を使用できます。
モデルのデフォルトのバージョンを使用する場合のモデル名のみ。
特定のモデル バージョンを使用する場合のモデル名とバージョン名。
Cloud Storage 上に存在し、AI Platform Prediction にデプロイされていない SavedModel を使用する場合のモデル URI。
- リージョン
ジョブを実行する Google Compute Engine リージョン。最高のパフォーマンスを得るには、予測ジョブを実行して、入力データと出力データを同じリージョンに格納する必要があります(特にきわめて大きなデータセットの場合)。AI Platform Prediction のバッチ予測は、以下のリージョンで利用可能です。
- us-central1 - us-east1 - europe-west1 - asia-east1
モデル トレーニングやオンライン予測などの AI Platform Prediction サービスで利用可能なリージョンの詳細については、リージョンのガイドをご覧ください。
- ジョブ名
ジョブの名前。次の条件が適用されます。
- 大文字と小文字が混在した(大文字と小文字が区別される)文字、数字、アンダースコアのみが含まれること。
- 文字から始まること。
- 128 文字以内であること。
- これまでプロジェクトで使用されていたすべてのトレーニング ジョブとバッチ予測ジョブの名前と重複していないこと。ジョブが成功したか、どのステータスであるかには関係なく、これにはプロジェクトで作成したすべてのジョブが含まれます。
- バッチサイズ(省略可)
バッチあたりのレコード数。このサービスは、モデルを呼び出す前に、
batch_size
の数のレコードをメモリにバッファします。指定しない場合のデフォルトは 64 です。- ラベル(省略可)
ジョブにラベルを追加して、リソースを表示するときやモニタリングするときにジョブをカテゴリに分類したり並べ替えたりできます。たとえば、チーム(
engineering
やresearch
などのラベルを追加)や開発フェーズ(prod
やtest
)に基づいてジョブを並べ替えることが可能です。予測ジョブにラベルを追加するには、KEY=VALUE
ペアのリストを指定します。- 最大ワーカー数(オプション)
このジョブの処理クラスタで使用する予測ノードの最大数。これは、バッチ予測の自動スケーリング機能に上限を設定する方法です。値を指定しない場合、デフォルトは 10 になります。指定した値にかかわらず、スケーリングは予測ノードの割り当てによって制限されます。
- ランタイム バージョン(オプション)
ジョブに使用する AI Platform Prediction のバージョン。このオプションにより、AI Platform Prediction にデプロイされていないモデルで使用するランタイム バージョンを指定できます。デプロイされたモデル バージョンに対しては、常にこの値を省略する必要があります。この値は、モデル バージョンのデプロイ時に指定された同じバージョンを使用するようにサービスに通知します。
- シグネチャ名(オプション)
保存したモデルに複数のシグネチャ名がある場合は、このオプションを使用して TensorFlow のカスタム シグネチャ名を指定します。これにより、TensorFlow SavedModel で定義された代替入出力マップを選択できます。シグネチャの使用方法や、カスタムモデルの出力を指定する方法については、SavedModel に関する TensorFlow ドキュメントをご覧ください。デフォルトは DEFAULT_SERVING_SIGNATURE_DEF_KEY で、その値は
serving_default
です。
次の例では、構成データを保持する変数を定義しています。
gcloud
gcloud コマンドライン ツールを使用してジョブを開始するときに、変数を作成する必要はありません。ただし、作成すると、ジョブの送信コマンドの入力と読み取りがはるかに容易になります。
DATA_FORMAT="text" # JSON data format
INPUT_PATHS='gs://path/to/your/input/data/*'
OUTPUT_PATH='gs://your/desired/output/location'
MODEL_NAME='census'
VERSION_NAME='v1'
REGION='us-east1'
now=$(date +"%Y%m%d_%H%M%S")
JOB_NAME="census_batch_predict_$now"
MAX_WORKER_COUNT="20"
BATCH_SIZE="32"
LABELS="team=engineering,phase=test,owner=sara"
Python
Python 用 Google API クライアント ライブラリを使用する場合、Python 辞書を使用して Job リソースと PredictionInput リソースを表すことができます。
AI Platform Prediction の REST API で使用されている構文で、プロジェクト名とモデル名またはバージョン名を記述します。
- project_name -> 'projects/project_name'
- model_name -> 'projects/project_name/models/model_name'
- version_name -> 'projects/project_name/models/model_name/versions/version_name'
Job リソースの辞書を作成し、2 つの項目を入力します。
値として使用するジョブ名を含む、
'jobId'
というキー。もう 1 つの辞書オブジェクトを含む
'predictionInput'
というキー。このオブジェクトには、PredictionInput のすべての必須メンバーと、使用する予定の任意のメンバーが格納されます。
次の例は、構成情報を入力変数として取り、予測リクエスト本文を返す関数を示しています。この例では、基本に加えて、プロジェクト名、モデル名、および現在の時刻に基づいてユニークなジョブ識別子も生成します。
import time import re def make_batch_job_body(project_name, input_paths, output_path, model_name, region, data_format='JSON', version_name=None, max_worker_count=None, runtime_version=None): project_id = 'projects/{}'.format(project_name) model_id = '{}/models/{}'.format(project_id, model_name) if version_name: version_id = '{}/versions/{}'.format(model_id, version_name) # Make a jobName of the format "model_name_batch_predict_YYYYMMDD_HHMMSS" timestamp = time.strftime('%Y%m%d_%H%M%S', time.gmtime()) # Make sure the project name is formatted correctly to work as the basis # of a valid job name. clean_project_name = re.sub(r'\W+', '_', project_name) job_id = '{}_{}_{}'.format(clean_project_name, model_name, timestamp) # Start building the request dictionary with required information. body = {'jobId': job_id, 'predictionInput': { 'dataFormat': data_format, 'inputPaths': input_paths, 'outputPath': output_path, 'region': region}} # Use the version if present, the model (its default version) if not. if version_name: body['predictionInput']['versionName'] = version_id else: body['predictionInput']['modelName'] = model_id # Only include a maximum number of workers or a runtime version if specified. # Otherwise let the service use its defaults. if max_worker_count: body['predictionInput']['maxWorkerCount'] = max_worker_count if runtime_version: body['predictionInput']['runtimeVersion'] = runtime_version return body
バッチ予測ジョブの送信
ジョブを送信するには、単純な projects.jobs.create 呼び出しを使用できます。またはコマンドライン ツールでこれに相当する gcloud ai-platform jobs submit prediction を使用できます。
gcloud
次の例では、前のセクションで定義した変数を使用してバッチ予測を開始します。
gcloud ai-platform jobs submit prediction $JOB_NAME \
--model $MODEL_NAME \
--input-paths $INPUT_PATHS \
--output-path $OUTPUT_PATH \
--region $REGION \
--data-format $DATA_FORMAT
Python
Python 用 Google API クライアント ライブラリを使用してバッチ予測ジョブを開始する手順は、他のクライアント SDK を使用する場合のパターンと同様です。
呼び出しに使用するリクエスト本文を準備します(これは前のセクションで示しています)。
ml.projects.jobs.create を呼び出してリクエストを作成します。
リクエストで呼び出しを実行し、レスポンスを取得して、HTTP エラーがないかチェックします。
レスポンスを辞書として使用し、Job リソースから値を取得します。
Python 用 Google API クライアント ライブラリを使用すると、HTTP リクエストを手動で作成しなくても、AI Platform Training API と Prediction API を呼び出すことができます。次のサンプルコードを実行する前に、認証を設定する必要があります。
import googleapiclient.discovery as discovery
project_id = 'projects/{}'.format(project_name)
ml = discovery.build('ml', 'v1')
request = ml.projects().jobs().create(parent=project_id,
body=batch_predict_body)
try:
response = request.execute()
print('Job requested.')
# The state returned will almost always be QUEUED.
print('state : {}'.format(response['state']))
except errors.HttpError as err:
# Something went wrong, print out some information.
print('There was an error getting the prediction results.' +
'Check the details:')
print(err._get_reason())
バッチ予測ジョブのモニタリング
バッチ予測ジョブは完了するまで長い時間がかかる可能性があります。ジョブの進行状況は Google Cloud コンソールで確認できます。
Google Cloud Console で AI Platform Prediction の [ジョブ] ページに移動します。
[ジョブ ID] リストでジョブの名前をクリックします。[ジョブの詳細] ページが開きます。
現在のステータスが、ジョブ名とともにページ上部に表示されます。
詳細を確認する場合は、[ログを表示] をクリックして、Cloud Logging にジョブのエントリを表示します。
バッチ予測ジョブの進行状況を追跡する他の方法もあります。そのような方法ではトレーニング ジョブのモニタリングと同じパターンに従います。詳細については、トレーニング ジョブのモニタリング方法を説明しているページを参照してください。予測ジョブを処理するには、そこの手順をわずかに調整する必要がありますが、メカニズムは同じです。
予測結果の取得
このサービスは指定された Cloud Storage の場所に予測を書き込みます。関心のある結果が含まれる可能性のあるファイル出力には、2 つのタイプがあります。
prediction.errors_stats-NNNNN-of-NNNNN
という名前のファイルには、ジョブ中に発生した問題に関する情報が含まれます。prediction.results-NNNNN-of-NNNNN
という名前の JSON Lines ファイルには、モデルの出力で定義された予測自体が含まれます。
ファイル名には、検索するファイルの総数をキャプチャするインデックス番号(上記の例で各桁の 'N')が含まれています。たとえば、6 つの結果ファイルがあるジョブには、prediction.results-00000-of-00006
から prediction.results-00005-of-00006
が含まれます。
各予測ファイルの各行は、1 つの予測結果を表す JSON オブジェクトです。任意のテキスト エディタで予測ファイルを開くことができます。コマンドラインで簡単に見るには、gcloud storage cat
を使うことができます。
gcloud storage cat $OUTPUT_PATH/prediction.results-NNNNN-of-NNNNN|less
1 つの入力ファイルしか使用しない場合でも、予測結果は通常、入力インスタンスと同じ順序で出力されません。インスタンスに対応する予測を見つけるには、インスタンス キーを照合します。
次のステップ
- オンライン予測を使用する。
- 予測プロセスの詳細を取得する。
- オンライン予測をリクエストするときに発生する問題のトラブルシューティングを行う。
- ジョブの整理にラベルを使用する方法を学ぶ。