使用 Model Garden 和 Vertex AI TPU 支持的端点部署和推理 Gemma


在此教程中,您将使用 Model Garden 将 Gemma 2B 开放模型部署到受 TPU 支持的 Vertex AI 端点。您必须先将模型部署到端点,然后才能使用该模型执行在线预测。部署模型会将物理资源与模型相关联,以便以低延迟方式执行在线预测。

部署 Gemma 2B 模型后,您可以使用 PredictionServiceClient 推断训练好的模型,以获得在线预测结果。在线预测是指向部署到端点的模型发出的同步请求。

目标

本教程介绍了如何执行以下任务:

  • 使用 Model Garden 将 Gemma 2B 开放模型部署到受 TPU 支持的端点
  • 使用 PredictionServiceClient 获取在线预测结果

费用

在本文档中,您将使用 Google Cloud的以下收费组件:

您可使用价格计算器根据您的预计使用情况来估算费用。

新 Google Cloud 用户可能有资格申请免费试用

完成本文档中描述的任务后,您可以通过删除所创建的资源来避免继续计费。如需了解详情,请参阅清理

准备工作

本教程要求您执行以下操作:

  • 设置 Google Cloud 项目并启用 Vertex AI API
  • 在本地机器上:
    • 安装、初始化 Google Cloud CLI 并通过它进行身份验证
    • 安装相应语言的 SDK

设置 Google Cloud 项目

设置 Google Cloud 项目并启用 Vertex AI API。

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Verify that billing is enabled for your Google Cloud project.

  4. Enable the Vertex AI API.

    Enable the API

  5. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  6. Verify that billing is enabled for your Google Cloud project.

  7. Enable the Vertex AI API.

    Enable the API

  8. 设置 Google Cloud CLI

    在本地机器上,设置 Google Cloud CLI。

    1. 安装并初始化 Google Cloud CLI。

    2. 如果您之前安装了 gcloud CLI,请运行此命令,确保您的 gcloud 组件已更新。

      gcloud components update
    3. 如需使用 gcloud CLI 进行身份验证,请运行此命令以生成本地应用默认凭据 (ADC) 文件。该命令启动的 Web 流程用于提供您的用户凭据。

      gcloud auth application-default login

      如需了解详情,请参阅 gcloud CLI 身份验证配置和 ADC 配置

    为您的编程语言设置 SDK

    如需设置本教程中使用的环境,请安装相应语言的 Vertex AI SDK 和 Protocol Buffers 库。代码示例使用 Protocol Buffers 库中的函数将输入字典转换为 API 所需的 JSON 格式。

    在本地机器上,点击以下某个标签页,以安装相应编程语言的 SDK。

    Python

    在本地机器上,点击以下某个标签页,以安装相应编程语言的 SDK。

    • 运行以下命令,安装并更新 Vertex AI SDK for Python。

      pip3 install --upgrade "google-cloud-aiplatform>=1.64"
    • 运行以下命令,安装适用于 Python 的 Protocol Buffers 库。

      pip3 install --upgrade "protobuf>=5.28"

    Node.js

    通过运行以下命令来安装或更新 Node.js 版 aiplatform SDK。

    npm install @google-cloud/aiplatform

    Java

    如需将 google-cloud-aiplatform 添加为依赖项,请添加适合您环境的代码。

    带有 BOM 的 Maven

    将以下 HTML 添加到 pom.xml 中:

    <dependencyManagement>
    <dependencies>
      <dependency>
        <artifactId>libraries-bom</artifactId>
        <groupId>com.google.cloud</groupId>
        <scope>import</scope>
        <type>pom</type>
        <version>26.34.0</version>
      </dependency>
    </dependencies>
    </dependencyManagement>
    <dependencies>
    <dependency>
      <groupId>com.google.cloud</groupId>
      <artifactId>google-cloud-aiplatform</artifactId>
    </dependency>
    <dependency>
      <groupId>com.google.protobuf</groupId>
      <artifactId>protobuf-java-util</artifactId>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
    </dependency>
    </dependencies>

    不带 BOM 的 Maven

    将以下内容添加到 pom.xml 中:

    <dependency>
      <groupId>com.google.cloud</groupId>
      <artifactId>google-cloud-aiplatform</artifactId>
      <version>1.1.0</version>
    </dependency>
    <dependency>
      <groupId>com.google.protobuf</groupId>
      <artifactId>protobuf-java-util</artifactId>
      <version>5.28</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.11.0</version>
    </dependency>

    不带 BOM 的 Gradle

    将以下内容添加到 build.gradle 中:

    implementation 'com.google.cloud:google-cloud-aiplatform:1.1.0'

    Go

    运行以下命令安装这些 Go 软件包。

    go get cloud.google.com/go/aiplatform
    go get google.golang.org/protobuf
    go get github.com/googleapis/gax-go/v2

    使用 Model Garden 部署 Gemma

    您将 Gemma 2B 模型部署到针对中小型规模训练优化的 ct5lp-hightpu-1t Compute Engine 机器类型。此机器具有一个 TPU v5e 加速器。如需详细了解如何使用 TPU 训练模型,请参阅 Cloud TPU v5e 训练

    在本教程中,您将使用 Model Garden 中的模型卡片部署指令调优的 Gemma 2B 开放模型。具体模型版本为 gemma2-2b-it - -it 表示指令调优

    Gemma 2B 模型的参数大小较低,这意味着资源要求较低,部署灵活性更高。

    1. 在 Google Cloud 控制台中,前往 Model Garden 页面。

      转到 Model Garden

    2. 点击 Gemma 2 模型卡片。

      前往 Gemma 2

    3. 点击部署以打开部署模型窗格。

    4. 部署模型窗格中,指定以下详细信息。

      1. 对于部署环境,请点击 Vertex AI

      2. 部署模型部分中:

        1. 资源 ID 部分,选择 gemma-2b-it

        2. 对于模型名称端点名称,请接受默认值。例如:

          • 模型名称:gemma2-2b-it-1234567891234
          • 端点名称:gemma2-2b-it-mg-one-click-deploy

          记下端点名称。您需要此文件来查找代码示例中使用的端点 ID。

      3. 部署设置部分中:

        1. 接受基本设置的默认选项。

        2. 区域部分,接受默认值或从列表中选择一个区域。记下相应区域。您需要它来运行代码示例。

        3. 对于机器配置,请选择由 TPU 支持的实例:ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t)

    5. 点击部署。部署完成后,您会收到一封电子邮件,其中包含有关新端点的详细信息。您还可以依次点击在线预测 > 端点,然后选择相应区域,查看端点详细信息。

      转至 Endpoints

    使用 PredictionServiceClient 推断 Gemma 2B

    部署 Gemma 2B 后,您可以使用 PredictionServiceClient 针对提示“为什么天空是蓝色的?”获取在线预测结果

    代码参数

    PredictionServiceClient 代码示例需要您更新以下内容。

    • PROJECT_ID:如需查找项目 ID,请按以下步骤操作。

      1. 前往 Google Cloud 控制台中的欢迎页面。

        前往“欢迎”页面

      2. 从页面顶部的项目选择器中,选择您的项目。

        项目名称、项目编号和项目 ID 会显示在欢迎标题下方。

    • ENDPOINT_REGION:这是您部署端点的区域。

    • ENDPOINT_ID:如需查找端点 ID,请在控制台中查看或运行 gcloud ai endpoints list 命令。您需要从部署模型窗格中获取端点名称和区域。

      控制台

      您可以依次点击在线预测 > 端点,然后选择您的区域,以查看端点详情。请注意 ID 列中显示的数字。

      转至 Endpoints

      gcloud

      您可以运行 gcloud ai endpoints list 命令来查看端点详细信息。

      gcloud ai endpoints list \
        --region=ENDPOINT_REGION \
        --filter=display_name=ENDPOINT_NAME
      

      输出类似于以下内容。

      Using endpoint [https://us-central1-aiplatform.googleapis.com/]
      ENDPOINT_ID: 1234567891234567891
      DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
      

    示例代码

    在相应语言的示例代码中,更新 PROJECT_IDENDPOINT_REGIONENDPOINT_ID。然后运行代码。

    Python

    如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档

    """
    Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
    """
    
    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value
    
    # TODO(developer): Update & uncomment lines below
    # PROJECT_ID = "your-project-id"
    # ENDPOINT_REGION = "your-vertex-endpoint-region"
    # ENDPOINT_ID = "your-vertex-endpoint-id"
    
    # Default configuration
    config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}
    
    # Prompt used in the prediction
    prompt = "Why is the sky blue?"
    
    # Encapsulate the prompt in a correct format for TPUs
    # Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    input = {"prompt": prompt}
    input.update(config)
    
    # Convert input message to a list of GAPIC instances for model input
    instances = [json_format.ParseDict(input, Value())]
    
    # Create a client
    api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )
    
    # Call the Gemma2 endpoint
    gemma2_end_point = (
        f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
    )
    response = client.predict(
        endpoint=gemma2_end_point,
        instances=instances,
    )
    text_responses = response.predictions
    print(text_responses[0])
    

    Node.js

    在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Node.js 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Node.js API 参考文档

    如需向 Vertex AI 进行身份验证,请设置应用默认凭证。 如需了解详情,请参阅为本地开发环境设置身份验证

    // Imports the Google Cloud Prediction Service Client library
    const {
      // TODO(developer): Uncomment PredictionServiceClient before running the sample.
      // PredictionServiceClient,
      helpers,
    } = require('@google-cloud/aiplatform');
    /**
     * TODO(developer): Update these variables before running the sample.
     */
    const projectId = 'your-project-id';
    const endpointRegion = 'your-vertex-endpoint-region';
    const endpointId = 'your-vertex-endpoint-id';
    
    // Prompt used in the prediction
    const prompt = 'Why is the sky blue?';
    
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
    const input = {
      prompt,
      // Parameters for default configuration
      maxOutputTokens: 1024,
      temperature: 0.9,
      topP: 1.0,
      topK: 1,
    };
    
    // Convert input message to a list of GAPIC instances for model input
    const instances = [helpers.toValue(input)];
    
    // TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
    // const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;
    
    // Create a client
    // predictionServiceClient = new PredictionServiceClient({apiEndpoint});
    
    // Call the Gemma2 endpoint
    const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;
    
    const [response] = await predictionServiceClient.predict({
      endpoint: gemma2Endpoint,
      instances,
    });
    
    const predictions = response.predictions;
    const text = predictions[0].stringValue;
    
    console.log('Predictions:', text);

    Java

    在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档

    如需向 Vertex AI 进行身份验证,请设置应用默认凭证。 如需了解详情,请参阅为本地开发环境设置身份验证

    
    import com.google.cloud.aiplatform.v1.EndpointName;
    import com.google.cloud.aiplatform.v1.PredictResponse;
    import com.google.cloud.aiplatform.v1.PredictionServiceClient;
    import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
    import com.google.gson.Gson;
    import com.google.protobuf.InvalidProtocolBufferException;
    import com.google.protobuf.Value;
    import com.google.protobuf.util.JsonFormat;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    public class Gemma2PredictTpu {
      private final PredictionServiceClient predictionServiceClient;
    
      // Constructor to inject the PredictionServiceClient
      public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
        this.predictionServiceClient = predictionServiceClient;
      }
    
      public static void main(String[] args) throws IOException {
        // TODO(developer): Replace these variables before running the sample.
        String projectId = "YOUR_PROJECT_ID";
        String endpointRegion = "us-west1";
        String endpointId = "YOUR_ENDPOINT_ID";
    
        PredictionServiceSettings predictionServiceSettings =
            PredictionServiceSettings.newBuilder()
                .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
                .build();
        PredictionServiceClient predictionServiceClient =
            PredictionServiceClient.create(predictionServiceSettings);
        Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
    
        creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
      }
    
      // Demonstrates how to run inference on a Gemma2 model
      // deployed to a Vertex AI endpoint with TPU accelerators.
      public String gemma2PredictTpu(String projectId, String region,
               String endpointId) throws IOException {
        Map<String, Object> paramsMap = new HashMap<>();
        paramsMap.put("temperature", 0.9);
        paramsMap.put("maxOutputTokens", 1024);
        paramsMap.put("topP", 1.0);
        paramsMap.put("topK", 1);
        Value parameters = mapToValue(paramsMap);
        // Prompt used in the prediction
        String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
        Value.Builder instanceValue = Value.newBuilder();
        JsonFormat.parser().merge(instance, instanceValue);
        // Encapsulate the prompt in a correct format for TPUs
        // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
        List<Value> instances = new ArrayList<>();
        instances.add(instanceValue.build());
    
        EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
    
        PredictResponse predictResponse = this.predictionServiceClient
            .predict(endpointName, instances, parameters);
        String textResponse = predictResponse.getPredictions(0).getStringValue();
        System.out.println(textResponse);
        return textResponse;
      }
    
      private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
        Gson gson = new Gson();
        String json = gson.toJson(map);
        Value.Builder builder = Value.newBuilder();
        JsonFormat.parser().merge(json, builder);
        return builder.build();
      }
    }

    Go

    在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Go 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Go API 参考文档

    如需向 Vertex AI 进行身份验证,请设置应用默认凭证。 如需了解详情,请参阅为本地开发环境设置身份验证

    import (
    	"context"
    	"fmt"
    	"io"
    
    	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
    
    	"google.golang.org/protobuf/types/known/structpb"
    )
    
    // predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
    func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
    	ctx := context.Background()
    
    	// Note: client can be initialized in the following way:
    	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
    	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
    	// if err != nil {
    	// 	return fmt.Errorf("unable to create prediction client: %v", err)
    	// }
    	// defer client.Close()
    
    	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
    	prompt := "Why is the sky blue?"
    	parameters := map[string]interface{}{
    		"temperature":     0.9,
    		"maxOutputTokens": 1024,
    		"topP":            1.0,
    		"topK":            1,
    	}
    
    	// Encapsulate the prompt in a correct format for TPUs.
    	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    	promptValue, err := structpb.NewValue(map[string]interface{}{
    		"prompt":     prompt,
    		"parameters": parameters,
    	})
    	if err != nil {
    		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
    		return err
    	}
    
    	req := &aiplatformpb.PredictRequest{
    		Endpoint:  gemma2Endpoint,
    		Instances: []*structpb.Value{promptValue},
    	}
    
    	resp, err := client.Predict(ctx, req)
    	if err != nil {
    		return err
    	}
    
    	prediction := resp.GetPredictions()
    	value := prediction[0].GetStringValue()
    	fmt.Fprintf(w, "%v", value)
    
    	return nil
    }
    

    清理

    为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

    删除项目

    1. In the Google Cloud console, go to the Manage resources page.

      Go to Manage resources

    2. In the project list, select the project that you want to delete, and then click Delete.
    3. In the dialog, type the project ID, and then click Shut down to delete the project.

    删除各个资源

    如果您要保留项目,请删除本教程中使用的资源:

    • 取消部署模型并删除端点
    • 从 Model Registry 中删除模型

    取消部署模型并删除端点

    请使用以下方法之一取消部署模型并删除端点。

    控制台

    1. 在 Google Cloud 控制台中,依次点击在线预测端点

      转到“端点”页面

    2. 区域下拉列表中,选择部署端点的区域。

    3. 点击端点名称以打开详情页面。例如:gemma2-2b-it-mg-one-click-deploy

    4. Gemma 2 (Version 1) 模型所对应的行中,点击 操作,然后点击从端点取消部署模型

    5. 从端点取消部署模型对话框中,点击取消部署

    6. 点击返回按钮,返回到端点页面。

      转到“端点”页面

    7. gemma2-2b-it-mg-one-click-deploy 行末尾,点击 操作,然后选择删除端点

    8. 在确认提示中,点击确认

    gcloud

    如需使用 Google Cloud CLI 取消部署模型并删除端点,请按以下步骤操作。

    在这些命令中,替换以下内容:

    • PROJECT_ID 替换为您的项目名称
    • LOCATION_ID 替换为您部署模型和端点的区域
    • ENDPOINT_ID 替换为端点 ID
    • DEPLOYED_MODEL_NAME 替换为模型的显示名称
    • DEPLOYED_MODEL_ID 替换为模型 ID
    1. 运行 gcloud ai endpoints list 命令获取端点 ID。此命令会列出项目中所有端点的 ID。记下本教程中使用的端点的 ID。

      gcloud ai endpoints list \
          --project=PROJECT_ID \
          --region=LOCATION_ID
      

      输出类似于以下内容。在输出中,该 ID 称为 ENDPOINT_ID

      Using endpoint [https://us-central1-aiplatform.googleapis.com/]
      ENDPOINT_ID: 1234567891234567891
      DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
      
    2. 运行 gcloud ai models describe 命令获取模型 ID。记下您在本教程中部署的模型的 ID。

      gcloud ai models describe DEPLOYED_MODEL_NAME \
          --project=PROJECT_ID \
          --region=LOCATION_ID
      

      缩略输出如下所示。在输出中,该 ID 称为 deployedModelId

      Using endpoint [https://us-central1-aiplatform.googleapis.com/]
      artifactUri: [URI removed]
      baseModelSource:
        modelGardenSource:
          publicModelName: publishers/google/models/gemma2
      ...
      deployedModels:
      - deployedModelId: '1234567891234567891'
        endpoint: projects/12345678912/locations/us-central1/endpoints/12345678912345
      displayName: gemma2-2b-it-12345678912345
      etag: [ETag removed]
      modelSourceInfo:
        sourceType: MODEL_GARDEN
      name: projects/123456789123/locations/us-central1/models/gemma2-2b-it-12345678912345
      ...
      
    3. 从端点取消部署模型。您需要使用之前命令中的端点 ID 和模型 ID。

      gcloud ai endpoints undeploy-model ENDPOINT_ID \
          --project=PROJECT_ID \
          --region=LOCATION_ID \
          --deployed-model-id=DEPLOYED_MODEL_ID
      

      此命令没有任何输出。

    4. 运行 gcloud ai endpoints delete 命令以删除端点。

      gcloud ai endpoints delete ENDPOINT_ID \
          --project=PROJECT_ID \
          --region=LOCATION_ID
      

      出现提示时,输入 y 以确认。此命令没有任何输出。

    删除模型

    控制台

    1. 从 Google Cloud 控制台的 Vertex AI 部分,进入 Model Registry 页面。

      进入 Model Registry 页面

    2. 区域下拉列表中,选择部署模型的区域。

    3. gemma2-2b-it-1234567891234 行末尾,点击 操作

    4. 选择删除模型

      删除模型时,所有关联的模型版本和评估都会从 Google Cloud 项目中删除。

    5. 在确认提示中,点击删除

    gcloud

    如需使用 Google Cloud CLI 删除模型,请向 gcloud ai models delete 命令提供模型的显示名称和区域。

    gcloud ai models delete DEPLOYED_MODEL_NAME \
        --project=PROJECT_ID \
        --region=LOCATION_ID
    

    DEPLOYED_MODEL_NAME 替换为模型的显示名称。 将 PROJECT_ID 替换为您的项目名称。将 LOCATION_ID 替换为您部署模型的区域。

    后续步骤