使用 Spanner 模拟器生成机器学习预测

本页面介绍了如何使用 Spanner 模拟器为 GoogleSQL 方言数据库和 PostgreSQL 方言数据库生成机器学习预测。

Spanner Vertex AI 集成可与 Spanner 模拟器搭配使用,以通过 GoogleSQL 或 PostgreSQL 机器学习预测函数生成预测。模拟器是一个模仿 Spanner 服务器的二进制程序,还可用于单元测试和集成测试。您可以将模拟器作为开源项目来使用,也可以通过 Google Cloud CLI 在本地使用模拟器。如需详细了解机器学习预测函数,请参阅“Spanner Vertex AI 集成的工作原理是怎样的?”。

您可以将任何模型与模拟器搭配使用来生成预测。您还可以使用 Vertex AI Model Garden 中的模型或部署到 Vertex AI 端点的模型。由于模拟器不连接到 Vertex AI,因此模拟器无法验证从 Vertex AI Model Garden 使用或部署到 Vertex AI 端点的任何模型及其架构。

默认情况下,当您将某个预测函数与模拟器搭配使用时,该函数会根据提供的模型输入和模型输出架构生成随机值。您可以使用回调函数修改模型输入和输出,并根据特定行为生成预测结果。

准备工作

在使用 Spanner 模拟器生成机器学习预测之前,请完成以下步骤。

安装 Spanner 模拟器

您可以在本地安装模拟器,也可以使用 GitHub 仓库进行设置。

选择型号

使用 ML.PREDICT(对于 GoogleSQL)或 ML_PREDICT_ROW(对于 PostgreSQL)函数时,您必须指定机器学习模型的位置。您可以使用任何经过训练的模型。如果您选择在 Vertex AI Model Garden 中运行的模型或部署到 Vertex AI 端点的模型,则必须为这些模型提供 inputoutput 值。

如需详细了解 Spanner Vertex AI 集成,请参阅 Spanner Vertex AI 集成的工作原理是怎样的?

生成预测

您可以使用模拟器通过 Spanner 机器学习预测函数生成预测。

默认行为

您可以将部署到端点的任何模型与 Spanner 模拟器搭配使用来生成预测。以下示例使用名为 FraudDetection 的模型生成结果。

GoogleSQL

如需详细了解如何使用 ML.PREDICT 函数生成预测,请参阅使用 SQL 生成机器学习预测

注册模型

您必须先使用 CREATE MODEL 语句注册模型,并提供 inputoutput 值,然后才能将模型与 ML.PREDICT 函数搭配使用:

CREATE MODEL FraudDetection
INPUT (Amount INT64, Name STRING(MAX))
OUTPUT (Outcome BOOL)
REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/PROJECT_ID/locations/REGION_ID/endpoints/ENDPOINT_ID'
);

替换以下内容:

  • PROJECT_ID:模型所在的 Google Cloud 项目的 ID

  • REGION_ID:模型所在的 Google Cloud 区域的 ID,例如 us-central1

  • ENDPOINT_ID:模型端点的 ID

运行预测

使用 ML.PREDICT GoogleSQL 函数生成预测。

SELECT Outcome
FROM ML.PREDICT(
    MODEL FraudDetection,
    (SELECT 1000 AS Amount, "John Smith" AS Name))

此查询的预期输出为 TRUE

PostgreSQL

如需详细了解如何使用 spanner.ML_PREDICT_ROW 函数生成预测,请参阅使用 SQL 生成机器学习预测

运行预测

使用 spanner.ML_PREDICT_ROW PostgreSQL 函数生成预测。

SELECT (spanner.ml_predict_row(
'projects/`MODEL_ID`/locations/`REGION_ID`/endpoints/`ENDPOINT_ID`',
'{"instances": [{"Amount": "1000", "Name": "John Smith"}]}'
)->'predictions'->0->'Outcome')::boolean

替换以下内容:

  • PROJECT_ID:模型所在的 Google Cloud 项目的 ID

  • REGION_ID:模型所在的 Google Cloud 区域的 ID,例如 us-central1

  • ENDPOINT_ID:模型端点的 ID

此查询的预期输出为 TRUE

自定义回调

您可以使用自定义回调函数来实现所选模型行为,并将特定模型输入转换为输出。以下示例使用 Vertex AI Model Garden 中的 gemini-pro 模型和 Spanner 模拟器,通过自定义回调生成预测。

为模型使用自定义回调时,您必须复刻 Spanner 模拟器仓库,然后构建并部署它。如需详细了解如何构建和部署 Spanner 模拟器,请参阅 Spanner 模拟器快速入门

GoogleSQL

注册模型

您必须先使用 CREATE MODEL 语句注册模型,然后才能将模型与 ML.PREDICT 函数搭配使用:

CREATE MODEL GeminiPro
INPUT (prompt STRING(MAX))
OUTPUT (content STRING(MAX))
REMOTE OPTIONS (
endpoint = '//aiplatform.googleapis.com/projects/PROJECT_ID/locations/REGION_ID/publishers/google/models/gemini-pro',
default_batch_size = 1
);

由于模拟器不连接到 Vertex AI,因此您必须提供 inputoutput 值。

替换以下内容:

  • PROJECT_ID:模型所在的 Google Cloud 项目的 ID

  • REGION_ID:模型所在的 Google Cloud 区域的 ID,例如 us-central1

回调

使用回调向 GeminiPro 模型添加自定义逻辑。

absl::Status ModelEvaluator::Predict(
    const googlesql::Model* model,
    const CaseInsensitiveStringMap<const ModelColumn>& model_inputs,
    CaseInsensitiveStringMap<ModelColumn>& model_outputs) {
  // Custom logic for GeminiPro.
  if (model->Name() == "GeminiPro") {
    RET_CHECK(model_inputs.contains("prompt"));
    RET_CHECK(model_inputs.find("prompt")->second.value->type()->IsString());
    RET_CHECK(model_outputs.contains("content"));
    std::string content;

    // Process prompts used in tests.
    int64_t number;
    static LazyRE2 is_prime_prompt = {R"(Is (\d+) a prime number\?)"};
    if (RE2::FullMatch(
            model_inputs.find("prompt")->second.value->string_value(),
            *is_prime_prompt, &number)) {
        content = IsPrime(number) ? "Yes" : "No";
    } else {
        // Default response.
        content = "Sorry, I don't understand";
    }
    *model_outputs["content"].value = googlesql::values::String(content);
    return absl::OkStatus();
  }
  // Custom model prediction logic can be added here.
  return DefaultPredict(model, model_inputs, model_outputs);
}

运行预测

使用 ML.PREDICT GoogleSQL 函数生成预测。

SELECT content
    FROM ML.PREDICT(MODEL GeminiPro, (SELECT "Is 7 a prime number?" AS prompt))

此查询的预期输出为 "YES"

PostgreSQL

使用 spanner.ML_PREDICT_ROW PostgreSQL 函数生成预测。

回调

使用回调向 GeminiPro 模型添加自定义逻辑。

absl::Status ModelEvaluator::PgPredict(
    absl::string_view endpoint, const googlesql::JSONValueConstRef& instance,
    const googlesql::JSONValueConstRef& parameters,
    lesql::JSONValueRef prediction) {
  if (endpoint.ends_with("publishers/google/models/gemini-pro")) {
    RET_CHECK(instance.IsObject());
    RET_CHECK(instance.HasMember("prompt"));
    std::string content;

    // Process prompts used in tests.
    int64_t number;
    static LazyRE2 is_prime_prompt = {R"(Is (\d+) a prime number\?)"};
    if (RE2::FullMatch(instance.GetMember("prompt").GetString(),
                        *is_prime_prompt, &number)) {
        content = IsPrime(number) ? "Yes" : "No";
    } else {
        // Default response.
        content = "Sorry, I don't understand";
    }
    prediction.SetToEmptyObject();
    prediction.GetMember("content").SetString(content);
    return absl::OkStatus();
  }

  // Custom model prediction logic can be added here.
  return DefaultPgPredict(endpoint, instance, parameters, prediction);
}

运行预测

SELECT (spanner.ml_predict_row(
'projects/`PROJECT_ID`/locations/`REGION_ID`/publishers/google/models/gemini-pro',
'{"instances": [{"prompt": "Is 7 a prime number?"}]}'
)->'predictions'->0->'content')::text

替换以下内容:

  • PROJECT_ID:模型所在的 Google Cloud 项目的 ID

  • REGION_ID:模型所在的 Google Cloud 区域的 ID,例如 us-central1

此查询的预期输出为 "YES"

后续步骤