カスタム トレーニング済みモデルからバッチ予測を取得する

このページでは、Google Cloud コンソールまたは Vertex AI API を使用して、カスタム トレーニング済みモデルからバッチ予測を取得する方法について説明します。

バッチ予測リクエストを行うには、入力ソースと出力先(Vertex AI がバッチ予測結果を保存する Cloud Storage または BigQuery)を指定します。

処理時間を最小限に抑えるには、入力と出力のロケーションを同じリージョンまたはマルチリージョンにする必要があります。たとえば、入力が us-central1 にある場合、出力は us-central1 または US に存在できますが、europe-west4 には存在できません。詳細については、Cloud Storage のロケーションBigQuery のロケーションをご覧ください。

また、入力と出力は、モデルと同じリージョンまたはマルチリージョンに存在する必要があります。

  • BigQuery ML モデルは Vertex AI Model Registry に登録する必要があります。
  • BigQuery テーブルを入力として使用するには、Vertex AI API を使用して InstanceConfig.instanceType"object" に設定する必要があります。

入力データの要件

一括リクエストの入力では、予測用のモデルに送信するアイテムを指定します。次の入力形式がサポートされています。

JSON Lines

JSON Lines ファイルを使用して、予測を行う入力インスタンスのリストを指定します。ファイルを Cloud Storage バケットに保存します。

例 1

次の例は、各行に配列が含まれている JSON Lines ファイルを示しています。

[1, 2, 3, 4]
[5, 6, 7, 8]

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。

その他すべてのコンテナ

{"instances": [ [1, 2, 3, 4], [5, 6, 7, 8] ]}

PyTorch コンテナ

{"instances": [
{ "data": [1, 2, 3, 4] },
{ "data": [5, 6, 7, 8] } ]}

例 2

次の例は、各行にオブジェクトが含まれている JSON Lines ファイルを示しています。

{ "values": [1, 2, 3, 4], "key": 1 }
{ "values": [5, 6, 7, 8], "key": 2 }

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。すべてのコンテナに同じリクエスト本文が送信されます。

{"instances": [
  { "values": [1, 2, 3, 4], "key": 1 },
  { "values": [5, 6, 7, 8], "key": 2 }
]}

例 3

PyTorch のビルド済みコンテナの場合は、TorchServe のデフォルト ハンドラの要件に従って各インスタンスを data フィールドにラップします。Vertex AI はインスタンスを自動的にラップしません。例:

{ "data": { "values": [1, 2, 3, 4], "key": 1 } }
{ "data": { "values": [5, 6, 7, 8], "key": 2 } }

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。

{"instances": [
  { "data": { "values": [1, 2, 3, 4], "key": 1 } },
  { "data": { "values": [5, 6, 7, 8], "key": 2 } }
]}

TFRecord

入力インスタンスを TFRecord 形式で保存します。必要に応じて、Gzip を使用して TFRecord ファイルを圧縮できます。TFRecord ファイルを Cloud Storage バケットに保存します。

Vertex AI は、TFRecord ファイル内の各インスタンスをバイナリとして読み取り、b64 という名前の単一キーを使用して、インスタンスを JSON オブジェクトとして base64 エンコードします。

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。

その他すべてのコンテナ

{"instances": [
{ "b64": "b64EncodedASCIIString" },
{ "b64": "b64EncodedASCIIString" } ]}

PyTorch コンテナ

{"instances": [ { "data": {"b64": "b64EncodedASCIIString" } }, { "data": {"b64": "b64EncodedASCIIString" } }
]}

予測コンテナがインスタンスのデコード方法を認識していることを確認します。

CSV

CSV ファイルで 1 行に 1 つの入力インスタンスを指定します。最初の行はヘッダー行にする必要があります。すべての文字列は二重引用符(")で囲む必要があります。Vertex AI では、改行を含むセル値は使用できません。引用符で囲まれていない値は浮動小数点数として読み取られます。

次の例は、2 つの入力インスタンスを含む CSV ファイルを示しています。

"input1","input2","input3"
0.1,1.2,"cat1"
4.0,5.0,"cat2"

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。

その他すべてのコンテナ

{"instances": [ [0.1,1.2,"cat1"], [4.0,5.0,"cat2"] ]}

PyTorch コンテナ

{"instances": [
{ "data": [0.1,1.2,"cat1"] },
{ "data": [4.0,5.0,"cat2"] } ]}

ファイルリスト

各行が Cloud Storage URI になっているテキスト ファイルを作成します。Vertex AI は各ファイルのコンテンツをバイナリとして読み取り、そのインスタンスを b64 という名前の単一のキーを持つ JSON オブジェクトとして base64 エンコードします。

Google Cloud コンソールを使用してバッチ予測を取得する場合は、ファイルリストを Google Cloud コンソールに直接貼り付けます。それ以外の場合は、リストを Cloud Storage バケットに保存します。

次の例は、2 つの入力インスタンスを含むファイルリストを示しています。

gs://path/to/image/image1.jpg
gs://path/to/image/image2.jpg

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。

その他すべてのコンテナ

{ "instances": [
{ "b64": "b64EncodedASCIIString" },
{ "b64": "b64EncodedASCIIString" } ]}

PyTorch コンテナ

{ "instances": [ { "data": { "b64": "b64EncodedASCIIString" } }, { "data": { "b64": "b64EncodedASCIIString" } }
]}

予測コンテナがインスタンスのデコード方法を認識していることを確認します。

BigQuery

BigQuery テーブルを projectId.datasetId.tableId として指定します。Vertex AI は、テーブルの各行を JSON インスタンスに変換します。

たとえば、テーブルに以下の対象が含まれているとします。

列 1 列 2 列 3
1.0 3.0 「Cat1」
2.0 4.0 「Cat2」

HTTP リクエスト本文で予測コンテナに送信される内容は次のとおりです。

その他すべてのコンテナ

{"instances": [ [1.0,3.0,"cat1"], [2.0,4.0,"cat2"] ]}

PyTorch コンテナ

{"instances": [
{ "data": [1.0,3.0,"cat1"] },
{ "data": [2.0,4.0,"cat2"] } ]}

BigQuery のデータ型がどのように JSON に変換されるかを以下に示します。

BigQuery の型 JSON 型 値の例
文字列 文字列 「abc」
整数 整数 1
浮動小数点数 浮動小数点数 1.2
数値 浮動小数点数 4925.000000000
ブール値 ブール値 true
タイムスタンプ 文字列 「2019-01-01 23:59:59.999999+00:00」
日付 文字列 「2018-12-31」
時間 文字列 「23:59:59.999999」
DateTime 文字列 「2019-01-01T00:00:00」
記録 オブジェクト { "A": 1,"B": 2}
繰り返しタイプ Array[Type] [1, 2]
ネストされたレコード オブジェクト {"A": {"a": 0}, "B": 1}

データをパーティショニングする

バッチ予測では、MapReduce を使用して入力を各レプリカにシャーディングします。MapReduce 機能を利用するには、入力がパーティショニング可能である必要があります。

Vertex AI によって、BigQueryファイルリストJSON Lines の入力が自動的にパーティショニングされます。

CSV ファイルは元来パーティショニングに適していないため、自動的パーティショニングは行われません。CSV ファイルの行は自己記述的でないほか、型付けされておらず、改行が含まれている可能性があります。スループットが重要なアプリケーションには、CSV 入力を使用しないことをおすすめします。

TFRecord 入力の場合は、インスタンスをより小さなファイルに分割し、ワイルドカード(gs://my-bucket/*.tfrecord など)を使用してファイルをジョブに渡すことで、データを手動でパーティショニングするようにしてください。ファイルの数は、指定したレプリカの数以上にする必要があります。

入力データをフィルタリングして変換する

バッチ入力のフィルタリングや変換を行うには、BatchPredictionJob リクエストで instanceConfig を指定します。

フィルタリングを使用すると、予測リクエストから入力データ内の特定のフィールドを除外するか、予測リクエストの入力データからフィールドのサブセットのみを含めることができます。予測コンテナでカスタムの前処理または後処理を実行する必要はありません。これは、入力データファイルに、キーや追加データなど、モデルに不要な列がある場合に有効です。

変換を使用すると、インスタンスを JSON array または object 形式で予測コンテナに送信できます。詳細については、instanceType をご覧ください。

たとえば、入力テーブルに以下の対象が含まれているとします。

customerId col1 col2
1001 1 2
1002 5 6

そして、以下の instanceConfig を指定します。

{
  "name": "batchJob1",
  ...
  "instanceConfig": {
    "excludedFields":["customerId"]
    "instanceType":"object"
  }
}

これで、予測リクエスト内のインスタンスが JSON オブジェクトとして送信され、customerId 列は除外されます。

{"col1":1,"col2":2}
{"col1":5,"col2":6}

次の instanceConfig を指定しても、同じ結果になります。

{
  "name": "batchJob1",
  ...
  "instanceConfig": {
    "includedFields": ["col1","col2"]
    "instanceType":"object"
  }
}

特徴フィルタの使用方法のデモについては、特徴フィルタリングを使用したカスタムモデルのバッチ予測のノートブックをご覧ください。

バッチ予測をリクエストする

バッチ予測リクエストの場合、Google Cloud コンソールまたは Vertex AI API を使用できます。送信した入力アイテム数によっては、バッチ予測タスクが完了するまでに時間がかかることがあります。

バッチ予測をリクエストすると、予測コンテナはユーザー指定のカスタム サービス アカウントとして実行されます。読み取り / 書き込みオペレーション(データソースからの予測インスタンスの読み取りや予測結果の書き込みなど)は、デフォルトで BigQuery と Cloud Storage にアクセスできる Vertex AI サービス エージェントを使用して行われます

Google Cloud コンソール

Google Cloud コンソールを使用してバッチ予測をリクエストします。

  1. Google Cloud コンソールの [Vertex AI] セクションで、[バッチ予測] ページに移動します。

[バッチ予測] ページに移動

  1. [作成] をクリックして、[新しいバッチ予測] ウィンドウを開きます。

  2. [バッチ予測の定義] で、次の手順を完了します。

    1. バッチ予測の名前を入力します。

    2. [モデル名] で、このバッチ予測に使用するモデルの名前を選択します。

    3. [ソースを選択] で、入力データに適用するソースを選択します。

      • 入力を JSON Lines、CSV、または TFRecord としてフォーマットしている場合は、[Cloud Storage 上のファイル(JSON Lines、CSV、TFRecord、TFRecord Gzip)] を選択します。次に、[転送元のパス] フィールドに入力ファイルを指定します。
      • 入力としてファイルリストを使用する場合は、[Cloud Storage 上のファイル(その他)] を選択して、次のフィールドにファイルリストを貼り付けます。
      • BigQuery 入力で、[BigQuery パス] を選択します。入力として BigQuery を選択する場合は、BigQuery を出力と Google 管理の暗号鍵として選択する必要もあります。BigQuery の入力 / 出力として、顧客管理の暗号鍵(CMEK)はサポートされていません。
    4. [宛先のパス] フィールドに、Vertex AI がバッチ予測の出力を保存する Cloud Storage ディレクトリを指定します。

    5. 必要に応じて、[このモデルの特徴アトリビューションを有効にする] をオンにすると、バッチ予測レスポンスの一部として特徴アトリビューションを取得できます。次に、[編集] をクリックして、説明の構成を行います(モデルの説明を以前に構成した場合は、説明の設定の編集は任意です。それ以外の場合は必須です)。

    6. コンピューティング ノードの数マシンタイプなどのバッチ予測ジョブのコンピューティング オプションを指定します。必要に応じて、アクセラレータ タイプアクセラレータ数も指定できます。

  3. 省略可: バッチ予測の Model Monitoring 分析はプレビュー版として提供されています。スキュー検出構成をバッチ予測ジョブに追加する方法については、前提条件をご覧ください。

    1. [このバッチ予測のモデルのモニタリングを有効にする] をクリックしてオンにします。

    2. トレーニング データソースを選択します。選択したトレーニング データソースのデータパスまたは場所を入力します。

    3. (省略可)[アラートのしきい値] で、アラートをトリガーするしきい値を指定します。

    4. [通知メール] に、モデルがアラートのしきい値を超えたときにアラートを受け取るメールアドレスを、カンマ区切り形式で 1 つ以上入力します。

    5. 省略可: 通知チャンネルの場合、モデルがアラートのしきい値を超えたときにアラートを受け取るには、Cloud Monitoring チャンネルを追加します。[通知チャンネルを管理] をクリックして、既存の Cloud Monitoring チャネルを選択するか、新しい Cloud Monitoring チャネルを作成できます。Google Cloud コンソールでは、PagerDuty、Slack、Pub/Sub の通知チャネルがサポートされています。

  4. [作成] をクリックします。

API

Vertex AI API を使用してバッチ予測リクエストを送信します。バッチ予測の取得に使用するツールに対応するタブを選択します。

REST

リクエストのデータを使用する前に、次のように置き換えます。

  • LOCATION_ID: モデルを保存し、バッチ予測ジョブを実行するリージョン。例: us-central1

  • PROJECT_ID: 実際のプロジェクト ID

  • BATCH_JOB_NAME: バッチ予測ジョブの表示名。

  • MODEL_ID: 予測に使用するモデルの ID。

  • INPUT_FORMAT: 入力データの形式: jsonlcsvtf-recordtf-record-gzipfile-list のいずれかになります。

  • INPUT_URI: 入力データの Cloud Storage URI。ワイルドカードを使用できます。

  • OUTPUT_DIRECTORY: Vertex AI が出力を保存するディレクトリの Cloud Storage URI。

  • MACHINE_TYPE: このバッチ予測ジョブに使用されるマシンリソース

    必要であれば、アクセラレータを使用するように machineSpec フィールドを構成できますが、次の例では説明していません。

  • BATCH_SIZE: 予測リクエストごとに送信するインスタンスの数。デフォルトは 64 です。バッチサイズを増やすとスループットは向上しますが、リクエストがタイムアウトする可能性もあります。

  • STARTING_REPLICA_COUNT: このバッチ予測ジョブのノード数。

HTTP メソッドと URL:

POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs

リクエストの本文(JSON):

{
  "displayName": "BATCH_JOB_NAME",
  "model": "projects/PROJECT_ID/locations/LOCATION_ID/models/MODEL_ID",
  "inputConfig": {
    "instancesFormat": "INPUT_FORMAT",
    "gcsSource": {
      "uris": ["INPUT_URI"],
    },
  },
  "outputConfig": {
    "predictionsFormat": "jsonl",
    "gcsDestination": {
      "outputUriPrefix": "OUTPUT_DIRECTORY",
    },
  },
  "dedicatedResources" : {
    "machineSpec" : {
      "machineType": MACHINE_TYPE
    },
    "startingReplicaCount": STARTING_REPLICA_COUNT
  },
  "manualBatchTuningParameters": {
    "batch_size": BATCH_SIZE,
  }
}

リクエストを送信するには、次のいずれかのオプションを選択します。

curl

リクエスト本文を request.json という名前のファイルに保存して、次のコマンドを実行します。

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs"

PowerShell

リクエスト本文を request.json という名前のファイルに保存して、次のコマンドを実行します。

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs" | Select-Object -Expand Content

次のような JSON レスポンスが返されます。

{
  "name": "projects/PROJECT_NUMBER/locations/LOCATION_ID/batchPredictionJobs/BATCH_JOB_ID",
  "displayName": "BATCH_JOB_NAME 202005291958",
  "model": "projects/PROJECT_ID/locations/LOCATION_ID/models/MODEL_ID",
  "inputConfig": {
    "instancesFormat": "jsonl",
    "gcsSource": {
      "uris": [
        "INPUT_URI"
      ]
    }
  },
  "outputConfig": {
    "predictionsFormat": "jsonl",
    "gcsDestination": {
      "outputUriPrefix": "OUTPUT_DIRECTORY"
    }
  },
  "state": "JOB_STATE_PENDING",
  "createTime": "2020-05-30T02:58:44.341643Z",
  "updateTime": "2020-05-30T02:58:44.341643Z",
}

Java

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

次のサンプルでは、PREDICTIONS_FORMATjsonl に置き換えます。他のプレースホルダを置き換える方法については、このセクションの REST & CMD LINE タブをご覧ください。

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.AcceleratorType;
import com.google.cloud.aiplatform.v1.BatchDedicatedResources;
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
import com.google.cloud.aiplatform.v1.GcsDestination;
import com.google.cloud.aiplatform.v1.GcsSource;
import com.google.cloud.aiplatform.v1.JobServiceClient;
import com.google.cloud.aiplatform.v1.JobServiceSettings;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.MachineSpec;
import com.google.cloud.aiplatform.v1.ModelName;
import com.google.protobuf.Value;
import java.io.IOException;

public class CreateBatchPredictionJobSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "PROJECT";
    String displayName = "DISPLAY_NAME";
    String modelName = "MODEL_NAME";
    String instancesFormat = "INSTANCES_FORMAT";
    String gcsSourceUri = "GCS_SOURCE_URI";
    String predictionsFormat = "PREDICTIONS_FORMAT";
    String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX";
    createBatchPredictionJobSample(
        project,
        displayName,
        modelName,
        instancesFormat,
        gcsSourceUri,
        predictionsFormat,
        gcsDestinationOutputUriPrefix);
  }

  static void createBatchPredictionJobSample(
      String project,
      String displayName,
      String model,
      String instancesFormat,
      String gcsSourceUri,
      String predictionsFormat,
      String gcsDestinationOutputUriPrefix)
      throws IOException {
    JobServiceSettings settings =
        JobServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
    String location = "us-central1";

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (JobServiceClient client = JobServiceClient.create(settings)) {

      // Passing in an empty Value object for model parameters
      Value modelParameters = ValueConverter.EMPTY_VALUE;

      GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build();
      BatchPredictionJob.InputConfig inputConfig =
          BatchPredictionJob.InputConfig.newBuilder()
              .setInstancesFormat(instancesFormat)
              .setGcsSource(gcsSource)
              .build();
      GcsDestination gcsDestination =
          GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
      BatchPredictionJob.OutputConfig outputConfig =
          BatchPredictionJob.OutputConfig.newBuilder()
              .setPredictionsFormat(predictionsFormat)
              .setGcsDestination(gcsDestination)
              .build();
      MachineSpec machineSpec =
          MachineSpec.newBuilder()
              .setMachineType("n1-standard-2")
              .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_T4)
              .setAcceleratorCount(1)
              .build();
      BatchDedicatedResources dedicatedResources =
          BatchDedicatedResources.newBuilder()
              .setMachineSpec(machineSpec)
              .setStartingReplicaCount(1)
              .setMaxReplicaCount(1)
              .build();
      String modelName = ModelName.of(project, location, model).toString();
      BatchPredictionJob batchPredictionJob =
          BatchPredictionJob.newBuilder()
              .setDisplayName(displayName)
              .setModel(modelName)
              .setModelParameters(modelParameters)
              .setInputConfig(inputConfig)
              .setOutputConfig(outputConfig)
              .setDedicatedResources(dedicatedResources)
              .build();
      LocationName parent = LocationName.of(project, location);
      BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
      System.out.format("response: %s\n", response);
      System.out.format("\tName: %s\n", response.getName());
    }
  }
}

Python

Vertex AI SDK for Python のインストールまたは更新の方法については、Vertex AI SDK for Python をインストールするをご覧ください。 詳細については、Python API リファレンス ドキュメントをご覧ください。

def create_batch_prediction_job_dedicated_resources_sample(
    project: str,
    location: str,
    model_resource_name: str,
    job_display_name: str,
    gcs_source: Union[str, Sequence[str]],
    gcs_destination: str,
    instances_format: str = "jsonl",
    machine_type: str = "n1-standard-2",
    accelerator_count: int = 1,
    accelerator_type: Union[str, aiplatform_v1.AcceleratorType] = "NVIDIA_TESLA_K80",
    starting_replica_count: int = 1,
    max_replica_count: int = 1,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    my_model = aiplatform.Model(model_resource_name)

    batch_prediction_job = my_model.batch_predict(
        job_display_name=job_display_name,
        gcs_source=gcs_source,
        gcs_destination_prefix=gcs_destination,
        instances_format=instances_format,
        machine_type=machine_type,
        accelerator_count=accelerator_count,
        accelerator_type=accelerator_type,
        starting_replica_count=starting_replica_count,
        max_replica_count=max_replica_count,
        sync=sync,
    )

    batch_prediction_job.wait()

    print(batch_prediction_job.display_name)
    print(batch_prediction_job.resource_name)
    print(batch_prediction_job.state)
    return batch_prediction_job

BigQuery

上記の REST の例では、ソースと出力先に Cloud Storage を使用しています。代わりに BigQuery を使用するには、次の変更を行います。

  • inputConfig フィールドを次のように変更します。

    "inputConfig": {
       "instancesFormat": "bigquery",
       "bigquerySource": {
          "inputUri": "bq://SOURCE_PROJECT_ID.SOURCE_DATASET_NAME.SOURCE_TABLE_NAME"
       }
    }
    
  • outputConfig フィールドを次のように変更します。

    "outputConfig": {
       "predictionsFormat":"bigquery",
       "bigqueryDestination":{
          "outputUri": "bq://DESTINATION_PROJECT_ID.DESTINATION_DATASET_NAME.DESTINATION_TABLE_NAME"
       }
     }
    
  • 次のように置き換えます。

    • SOURCE_PROJECT_ID: ソースの Google Cloud プロジェクトの ID
    • SOURCE_DATASET_NAME: ソースの BigQuery データセットの名前
    • SOURCE_TABLE_NAME: BigQuery ソーステーブルの名前
    • DESTINATION_PROJECT_ID: 出力先の Google Cloud プロジェクトの ID
    • DESTINATION_DATASET_NAME: 出力先の BigQuery データセットの名前
    • DESTINATION_TABLE_NAME: BigQuery 出力先テーブルの名前

特徴量の重要度

予測で特徴量の重要度の値が返されるようにするには、generateExplanation プロパティを true に設定します。なお、予測モデルでは、特徴量の重要度がサポートされていないため、バッチ予測リクエストに含めることはできません。

特徴量の重要度(特徴アトリビューション)は Vertex Explainable AI の一部です。

generateExplanationtrue に設定できるのは、説明用に Model を構成している場合か、BatchPredictionJobexplanationSpec フィールドを指定している場合だけです。

マシンタイプとレプリカ数を選択する

レプリカの数を増やすことで水平方向にスケーリングすると、より大きなマシンタイプを使用するよりも、線形かつ予測可能な方法でスループットが向上します。

一般に、ジョブに可能な限り小さいマシンタイプを指定し、レプリカの数を増やすことをおすすめします。

費用対効果の観点から、バッチ予測ジョブが 10 分以上実行されるようなレプリカ数を選択することをおすすめします。これは、レプリカノード時間あたりで課金され、各レプリカの起動にかかる約 5 分がこれに含まれるためです。数秒だけ処理してシャットダウンすると、コスト効率は良くありません。

一般的なガイダンスとして、数千のインスタンスの場合は starting_replica_count を数十にすることをおすすめします。数百万のインスタンスの場合は、starting_replica_count を数百にすることをおすすめします。次の数式を使用してレプリカの数を見積もることもできます。

N / (T * (60 / Tb))

ここで

  • N: ジョブ内のバッチ数。例: 100 万インスタンス ÷ 100 バッチサイズ = 10,000 バッチ。
  • T: バッチ予測ジョブの所要時間。例: 10 分。
  • Tb: レプリカが 1 つのバッチを処理するのに要する時間(秒)。たとえば、2 コアのマシンタイプでバッチあたり 1 秒。

この例では、10,000 バッチ ÷ (10 分 × (60 ÷ 1 秒)) は 17 レプリカに切り上げられます。

オンライン予測とは異なり、バッチ予測ジョブは自動スケーリングされません。すべての入力データが事前にわかっているため、ジョブの開始時にシステムによってデータが各レプリカにパーティショニングされます。システムで使用されるのは starting_replica_count パラメータです。max_replica_count パラメータは無視されます。

これらの推奨事項はすべておおよそのガイドラインです。すべてのモデルに最適なスループットを保証するものではありません。処理時間とコストの正確な予測を提示するものではありません。また、各シナリオで認識されているコストとスループットのトレードオフが最適であるとは限りません。これらを妥当な出発点として使用し、必要に応じて調整してください。モデルのスループットなどの特性を測定するには、最適なマシンタイプを見つけるのノートブックを実行します。

GPU または TPU アクセラレーションを使用するマシンの場合

次の追加の考慮事項とともに、前述のガイドラインに従ってください(これは、CPU のみのモデルにも適用されます)。

  • より多くの CPU と GPU(データの前処理用など)が必要になる場合があります。
  • GPU マシンタイプは起動に時間がかかる(10 分など)ため、バッチ予測ジョブでさらに長い時間(たとえば、10 分ではなく 20 分)をターゲットにして、予測の生成に合理的な比率の時間とコストが費やされるようにできます。

バッチ予測の結果を取得する

バッチ予測タスクが完了すると、リクエストで指定した Cloud Storage バケットまたは BigQuery のロケーションに予測の出力が保存されます。

バッチ予測結果の例

出力フォルダには、一連の JSON Lines ファイルが含まれます。

ファイル名は、{gcs_path}/prediction.results-{file_number}-of-{number_of_files_generated} になります。バッチ予測には分散的な性質があるため、ファイルの数は非確定的です。

ファイルの各行は入力のインスタンスに対応し、次の Key-Value ペアを持ちます。

  • prediction: 予測コンテナによって返された値が含まれます。
  • instance: ファイルリストの場合、Cloud Storage URI が含まれます。他のすべての入力形式の場合、HTTP リクエスト本文で予測コンテナに送信された値が含まれます。

例 1

HTTP リクエストに以下の対象が含まれている場合:

{
  "instances": [
    [1, 2, 3, 4],
    [5, 6, 7, 8]
]}

予測コンテナが以下の対象を返す場合:

{
  "predictions": [
    [0.1,0.9],
    [0.7,0.3]
  ],
}

JSON Lines 出力ファイルは次のようになります。

{ "instance": [1, 2, 3, 4], "prediction": [0.1,0.9]}
{ "instance": [5, 6, 7, 8], "prediction": [0.7,0.3]}

例 2

HTTP リクエストに以下の対象が含まれている場合:

{
  "instances": [
    {"values": [1, 2, 3, 4], "key": 1},
    {"values": [5, 6, 7, 8], "key": 2}
]}

予測コンテナが以下の対象を返す場合:

{
  "predictions": [
    {"result":1},
    {"result":0}
  ],
}

JSON Lines 出力ファイルは次のようになります。

{ "instance": {"values": [1, 2, 3, 4], "key": 1}, "prediction": {"result":1}}
{ "instance": {"values": [5, 6, 7, 8], "key": 2}, "prediction": {"result":0}}

Explainable AI を使用する

大量のデータに対して特徴ベースの説明を実行することはおすすめしません。特徴値のセットによっては各入力が数千のリクエストにファンアウトする可能性があり、処理時間と費用の大幅な増加につながる恐れがあるためです。一般に、特徴の重要性を把握するには小さなデータセットで十分です。

バッチ予測は例ベースの説明には対応していません。

ノートブック

次のステップ