Spanner エミュレータを使用して ML 予測を生成する

このページでは、Spanner エミュレータを使用して ML 予測を生成する方法について説明します。

Spanner Vertex AI インテグレーションと Spanner エミュレータを使用して、GoogleSQL または PostgreSQL ML 予測関数を使用して予測を生成できます。エミュレータは Spanner サーバーを模倣したバイナリであり、単体テストと統合テストでも使用できます。エミュレータは、オープンソース プロジェクトとして、または Google Cloud CLI を使用してローカルで使用できます。ML 予測関数の詳細については、Spanner Vertex AI インテグレーションの仕組みをご覧ください。

エミュレータで任意のモデルを使用して予測を生成できます。 Vertex AI Model Garden のモデルや、Vertex AI エンドポイントにデプロイされたモデルを使用することもできます。エミュレータは Vertex AI に接続しないため、Vertex AI Model Garden で使用されているモデルや Vertex AI エンドポイントにデプロイされているモデルのモデルまたはそのスキーマを検証できません。

デフォルトでは、エミュレータで予測関数を使用すると、指定されたモデル入力とモデル出力スキーマに基づいてランダムな値が生成されます。コールバック関数を使用すると、モデルの入力と出力を変更し、特定の動作に基づいて予測結果を生成できます。

始める前に

Spanner エミュレータをインストールする

エミュレータをローカルにインストールするか、GitHub リポジトリを使用して設定します。

モデルの選択

ML.PREDICT(GoogleSQL の場合)または ML_PREDICT_ROW(PostgreSQL の場合)関数を使用する場合は、ML モデルの場所を指定する必要があります。任意のトレーニング済みモデルを使用できます。Vertex AI Model Garden で実行されているモデルまたは Vertex AI エンドポイントにデプロイされているモデルを選択する場合は、これらのモデルの input 値と output 値を指定する必要があります。

Spanner Vertex AI インテグレーションの詳細については、Spanner Vertex AI インテグレーションの仕組みをご覧ください。

予測を生成する

エミュレータを使用して、Spanner ML 予測関数を使用して予測を生成できます。

デフォルトの動作

Spanner エミュレータでエンドポイントにデプロイされた任意のモデルを使用して予測を生成できます。次の例では、FraudDetection というモデルを使用して結果を生成します。

GoogleSQL

ML.PREDICT 関数を使用して予測を生成する方法については、SQL を使用して ML 予測を生成するをご覧ください。

モデルを登録する

ML.PREDICT 関数でモデルを使用する前に、CREATE MODEL ステートメントを使用してモデルを登録し、inputoutput の値を指定する必要があります。

CREATE MODEL FraudDetection
INPUT (Amount INT64, Name STRING(MAX))
OUTPUT (Outcome BOOL)
REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/PROJECT_ID/locations/REGION_ID/endpoints/ENDPOINT_ID'
);

以下を置き換えます。

  • PROJECT_ID: モデルが配置されている Google Cloud プロジェクトの ID

  • REGION_ID: モデルが配置されている Google Cloud リージョンの ID(例: us-central1

  • ENDPOINT_ID: モデル エンドポイントの ID

予測を実行する

ML.PREDICT GoogleSQL 関数を使用して予測を生成します。

SELECT Outcome
FROM ML.PREDICT(
    MODEL FraudDetection,
    (SELECT 1000 AS Amount, "John Smith" AS Name))

このコマンドの想定される出力は TRUE です。

PostgreSQL

spanner.ML_PREDICT_ROW 関数を使用して予測を生成する方法については、SQL を使用して ML 予測を生成するをご覧ください。

予測を実行する

spanner.ML_PREDICT_ROW PostgreSQL 関数を使用して予測を生成します。

SELECT (spanner.ml_predict_row(
'projects/`MODEL_ID`/locations/`REGION_ID`/endpoints/`ENDPOINT_ID`',
'{"instances": [{"Amount": "1000", "Name": "John Smith"}]}'
)->'predictions'->0->'Outcome')::boolean

以下を置き換えます。

  • PROJECT_ID: モデルが配置されている Google Cloud プロジェクトの ID

  • REGION_ID: モデルが配置されている Google Cloud リージョンの ID(例: us-central1

  • ENDPOINT_ID: モデル エンドポイントの ID

このコマンドの想定される出力は TRUE です。

カスタム コールバック

カスタム コールバック関数を使用すると、選択したモデルの動作を実装し、特定のモデル入力を出力に変換できます。次の例では、Vertex AI Model Garden の gemini-pro モデルと Spanner エミュレータを使用して、カスタム コールバックを使用して予測を生成します。

モデルにカスタム コールバックを使用する場合は、Spanner エミュレータ リポジトリフォークしてから、ビルドしてデプロイする必要があります。Spanner エミュレータをビルドしてデプロイする方法については、Spanner エミュレータのクイックスタートをご覧ください。

GoogleSQL

モデルを登録する

ML.PREDICT 関数でモデルを使用する前に、CREATE MODEL ステートメントを使用してモデルを登録する必要があります。

CREATE MODEL GeminiPro
INPUT (prompt STRING(MAX))
OUTPUT (content STRING(MAX))
REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/PROJECT_ID/locations/REGION_ID/publishers/google/models/gemini-pro',
default_batch_size = 1
);

エミュレータは Vertex AI に接続しないため、inputoutput の値を指定する必要があります。

以下を置き換えます。

  • PROJECT_ID: モデルが配置されている Google Cloud プロジェクトの ID

  • REGION_ID: モデルが配置されている Google Cloud リージョンの ID(例: us-central1

コールバック

コールバックを使用して、GeminiPro モデルにカスタム ロジックを追加します。

absl::Status ModelEvaluator::Predict(
    const googlesql::Model* model,
    const CaseInsensitiveStringMap<const ModelColumn>& model_inputs,
    CaseInsensitiveStringMap<ModelColumn>& model_outputs) {
  // Custom logic for GeminiPro.
  if (model->Name() == "GeminiPro") {
    RET_CHECK(model_inputs.contains("prompt"));
    RET_CHECK(model_inputs.find("prompt")->second.value->type()->IsString());
    RET_CHECK(model_outputs.contains("content"));
    std::string content;

    // Process prompts used in tests.
    int64_t number;
    static LazyRE2 is_prime_prompt = {R"(Is (\d+) a prime number\?)"};
    if (RE2::FullMatch(
            model_inputs.find("prompt")->second.value->string_value(),
            *is_prime_prompt, &number)) {
        content = IsPrime(number) ? "Yes" : "No";
    } else {
        // Default response.
        content = "Sorry, I don't understand";
    }
    *model_outputs["content"].value = googlesql::values::String(content);
    return absl::OkStatus();
  }
  // Custom model prediction logic can be added here.
  return DefaultPredict(model, model_inputs, model_outputs);
}

予測を実行する

ML.PREDICT GoogleSQL 関数を使用して予測を生成します。

SELECT content
    FROM ML.PREDICT(MODEL GeminiPro, (SELECT "Is 7 a prime number?" AS prompt))

このコマンドの想定される出力は "YES" です。

PostgreSQL

spanner.ML_PREDICT_ROW PostgreSQL 関数を使用して予測を生成します。

コールバック

コールバックを使用して、GeminiPro モデルにカスタム ロジックを追加します。

absl::Status ModelEvaluator::PgPredict(
    absl::string_view endpoint, const googlesql::JSONValueConstRef& instance,
    const googlesql::JSONValueConstRef& parameters,
    lesql::JSONValueRef prediction) {
  if (endpoint.ends_with("publishers/google/models/gemini-pro")) {
    RET_CHECK(instance.IsObject());
    RET_CHECK(instance.HasMember("prompt"));
    std::string content;

    // Process prompts used in tests.
    int64_t number;
    static LazyRE2 is_prime_prompt = {R"(Is (\d+) a prime number\?)"};
    if (RE2::FullMatch(instance.GetMember("prompt").GetString(),
                        *is_prime_prompt, &number)) {
        content = IsPrime(number) ? "Yes" : "No";
    } else {
        // Default response.
        content = "Sorry, I don't understand";
    }
    prediction.SetToEmptyObject();
    prediction.GetMember("content").SetString(content);
    return absl::OkStatus();
  }

  // Custom model prediction logic can be added here.
  return DefaultPgPredict(endpoint, instance, parameters, prediction);
}

予測を実行する


SELECT (spanner.ml_predict_row(
'projects/`PROJECT_ID`/locations/`REGION_ID`/publishers/google/models/gemini-pro',
'{"instances": [{"prompt": "Is 7 a prime number?"}]}'
)->'predictions'->0->'content')::text

以下を置き換えます。

  • PROJECT_ID: モデルが配置されている Google Cloud プロジェクトの ID

  • REGION_ID: モデルが配置されている Google Cloud リージョンの ID(例: us-central1

このコマンドの想定される出力は "YES" です。

次のステップ