使用 Model Garden 和 Vertex AI TPU 支持的端点部署和推理 Gemma

在此教程中,您将使用 Model Garden 将 Gemma 2B 开放模型部署到受 TPU 支持的 Vertex AI 端点。您必须先将模型部署到端点,然后才能使用该模型执行在线预测。部署模型会将物理资源与模型相关联,以便以低延迟方式执行在线预测。

部署 Gemma 2B 模型后,您可以使用 PredictionServiceClient 获取在线预测结果,以通过经过训练的模型进行推理。在线预测是指向部署到端点的模型发出的同步请求。

目标

本教程介绍了如何执行以下任务:

  • 使用 Model Garden 将 Gemma 2B 开放模型部署到受 TPU 支持的端点
  • 使用 PredictionServiceClient 获取在线预测结果

费用

在本文档中,您将使用 Google Cloud的以下收费组件:

如需根据您的预计使用量来估算费用,请使用价格计算器

新 Google Cloud 用户可能有资格申请免费试用

完成本文档中描述的任务后,您可以通过删除所创建的资源来避免继续计费。如需了解详情,请参阅清理

准备工作

本教程要求您执行以下操作:

  • 设置 Google Cloud 项目并启用 Vertex AI API
  • 在您的本地机器上:
    • 安装、初始化 Google Cloud CLI 并向其进行身份验证
    • 安装相应编程语言的 SDK

设置 Google Cloud 项目

设置 Google Cloud 项目并启用 Vertex AI API。

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  3. Verify that billing is enabled for your Google Cloud project.

  4. Enable the Vertex AI API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  5. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  6. Verify that billing is enabled for your Google Cloud project.

  7. Enable the Vertex AI API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API

  8. 设置 Google Cloud CLI

    在您的本地机器上,设置 Google Cloud CLI。

    1. 安装并初始化 Google Cloud CLI。

    2. 如果您之前安装了 gcloud CLI,请确保通过运行此命令更新您的 gcloud 组件。

      gcloud components update
    3. 如要向 gcloud CLI 进行身份验证,请通过运行此命令生成本地应用默认凭证 (ADC) 文件。该命令启动的 Web 流程用于提供您的用户凭证。

      gcloud auth application-default login

      如需了解详情,请参阅 gcloud CLI 身份验证配置和 ADC 配置

    设置相应编程语言的 SDK

    如需设置本教程中使用的环境,请安装相应编程语言的 Vertex AI SDK 及协议缓冲区库。代码示例会使用协议缓冲区库中的函数将输入字典转换为 API 所需的 JSON 格式。

    在您的本地机器上,点击以下标签页之一,安装相应编程语言的 SDK。

    Python

    在您的本地机器上,点击以下标签页之一,安装相应编程语言的 SDK。

    • 运行以下命令,安装并更新 Vertex AI SDK for Python。

      pip3 install --upgrade "google-cloud-aiplatform>=1.64"
    • 运行以下命令,安装 Python 版协议缓冲区库。

      pip3 install --upgrade "protobuf>=5.28"

    Node.js

    运行以下命令,安装或更新 Node.js 版 aiplatform SDK。

    npm install @google-cloud/aiplatform

    Java

    如需将 google-cloud-aiplatform 添加为依赖项,请为您的环境添加相应的代码。

    带有 BOM 的 Maven

    将以下 HTML 添加到 pom.xml 中:

    <dependencyManagement>
    <dependencies>
      <dependency>
        <artifactId>libraries-bom</artifactId>
        <groupId>com.google.cloud</groupId>
        <scope>import</scope>
        <type>pom</type>
        <version>26.34.0</version>
      </dependency>
    </dependencies>
    </dependencyManagement>
    <dependencies>
    <dependency>
      <groupId>com.google.cloud</groupId>
      <artifactId>google-cloud-aiplatform</artifactId>
    </dependency>
    <dependency>
      <groupId>com.google.protobuf</groupId>
      <artifactId>protobuf-java-util</artifactId>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
    </dependency>
    </dependencies>

    不带 BOM 的 Maven

    将以下内容添加到 pom.xml 中:

    <dependency>
      <groupId>com.google.cloud</groupId>
      <artifactId>google-cloud-aiplatform</artifactId>
      <version>1.1.0</version>
    </dependency>
    <dependency>
      <groupId>com.google.protobuf</groupId>
      <artifactId>protobuf-java-util</artifactId>
      <version>5.28</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.11.0</version>
    </dependency>

    不带 BOM 的 Gradle

    将以下内容添加到 build.gradle 中:

    implementation 'com.google.cloud:google-cloud-aiplatform:1.1.0'

    Go

    运行以下命令,安装下面的 Go 软件包。

    go get cloud.google.com/go/aiplatform
    go get google.golang.org/protobuf
    go get github.com/googleapis/gax-go/v2

    使用 Model Garden 部署 Gemma

    将 Gemma 2B 模型部署到针对中小规模训练优化的 ct5lp-hightpu-1t Compute Engine 机器类型。该类型的机器有一个 TPU v5e 加速器。如需详细了解如何使用 TPU 训练模型,请参阅 Cloud TPU v5e 训练

    在本教程中,您将使用 Model Garden 中的模型卡片部署指令调优的 Gemma 2B 开放模型。具体模型版本为 gemma2-2b-it - -it 表示指令调优

    Gemma 2B 模型的参数规模较小,这意味着对资源的要求较低,同时能够提供较高的部署灵活性。

    1. 在 Google Cloud 控制台中,前往 Model Garden 页面。

      转到 Model Garden

    2. 点击 Gemma 2 模型卡片。

      前往 Gemma 2

    3. 点击部署以打开部署模型窗格。

    4. 部署模型窗格中,指定以下详细信息。

      1. 部署环境部分,点击 Vertex AI

      2. 部署模型部分:

        1. 资源 ID 部分,选择 gemma-2b-it

        2. 对于模型名称端点名称,接受默认值即可。例如:

          • 模型名称:gemma2-2b-it-1234567891234
          • 端点名称:gemma2-2b-it-mg-one-click-deploy

          记下端点名称。您需要用该名称来查找代码示例中使用的端点 ID。

      3. 部署设置部分:

        1. 对于基本设置,接受默认选项即可。

        2. 区域部分,接受默认值或从列表中选择一个区域。记下相应区域。代码示例需要用到该区域。

        3. 机器规格部分,选择受 TPU 支持的实例:ct5lp-hightpu-1t (1 TPU_V5_LITEPOD; ct5lp-hightpu-1t)

    5. 点击部署。部署完成后,您会收到一封邮件,其中包含有关新端点的详细信息。您也可以通过依次点击在线预测 > 端点并选择相应区域,来查看端点详细信息。

      转至 Endpoints

    使用 PredictionServiceClient 推断 Gemma 2B

    部署 Gemma 2B 后,您可以使用 PredictionServiceClient 获取以下提示的在线预测结果:“为什么天空是蓝色的?”。

    代码参数

    PredictionServiceClient 代码示例需要您更新以下内容。

    • PROJECT_ID:如需查找项目 ID,请按以下步骤操作。

      1. 前往 Google Cloud 控制台中的欢迎页面。

        前往“欢迎”页面

      2. 从页面顶部的项目选择器中,选择您的项目。

        项目名称、项目编号和项目 ID 会显示在欢迎标头后面。

    • ENDPOINT_REGION:这是您在其中部署端点的区域。

    • ENDPOINT_ID:如要查找端点 ID,您可以在控制台中查看,或者运行 gcloud ai endpoints list 命令。您需要记下部署模型窗格中的端点名称和区域。

      控制台

      您可以通过依次点击在线预测 > 端点并选择相应区域,来查看端点详细信息。请注意 ID 列中显示的数字。

      转至 Endpoints

      gcloud

      您可以运行 gcloud ai endpoints list 命令来查看端点详细信息。

      gcloud ai endpoints list \
        --region=ENDPOINT_REGION \
        --filter=display_name=ENDPOINT_NAME
      

      输出类似于以下内容。

      Using endpoint [https://us-central1-aiplatform.googleapis.com/]
      ENDPOINT_ID: 1234567891234567891
      DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
      

    示例代码

    在相应编程语言的示例代码中,更新 PROJECT_IDENDPOINT_REGIONENDPOINT_ID。然后运行代码。

    Python

    如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档

    """
    Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accellerators.
    """
    
    from google.cloud import aiplatform
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value
    
    # TODO(developer): Update & uncomment lines below
    # PROJECT_ID = "your-project-id"
    # ENDPOINT_REGION = "your-vertex-endpoint-region"
    # ENDPOINT_ID = "your-vertex-endpoint-id"
    
    # Default configuration
    config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}
    
    # Prompt used in the prediction
    prompt = "Why is the sky blue?"
    
    # Encapsulate the prompt in a correct format for TPUs
    # Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    input = {"prompt": prompt}
    input.update(config)
    
    # Convert input message to a list of GAPIC instances for model input
    instances = [json_format.ParseDict(input, Value())]
    
    # Create a client
    api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
    client = aiplatform.gapic.PredictionServiceClient(
        client_options={"api_endpoint": api_endpoint}
    )
    
    # Call the Gemma2 endpoint
    gemma2_end_point = (
        f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
    )
    response = client.predict(
        endpoint=gemma2_end_point,
        instances=instances,
    )
    text_responses = response.predictions
    print(text_responses[0])
    

    Node.js

    在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Node.js 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Node.js API 参考文档

    如需向 Vertex AI 进行身份验证,请设置应用默认凭证。 如需了解详情,请参阅为本地开发环境设置身份验证

    // Imports the Google Cloud Prediction Service Client library
    const {
      // TODO(developer): Uncomment PredictionServiceClient before running the sample.
      // PredictionServiceClient,
      helpers,
    } = require('@google-cloud/aiplatform');
    /**
     * TODO(developer): Update these variables before running the sample.
     */
    const projectId = 'your-project-id';
    const endpointRegion = 'your-vertex-endpoint-region';
    const endpointId = 'your-vertex-endpoint-id';
    
    // Prompt used in the prediction
    const prompt = 'Why is the sky blue?';
    
    // Encapsulate the prompt in a correct format for TPUs
    // Example format: [{prompt: 'Why is the sky blue?', temperature: 0.9}]
    const input = {
      prompt,
      // Parameters for default configuration
      maxOutputTokens: 1024,
      temperature: 0.9,
      topP: 1.0,
      topK: 1,
    };
    
    // Convert input message to a list of GAPIC instances for model input
    const instances = [helpers.toValue(input)];
    
    // TODO(developer): Uncomment apiEndpoint and predictionServiceClient before running the sample.
    // const apiEndpoint = `${endpointRegion}-aiplatform.googleapis.com`;
    
    // Create a client
    // predictionServiceClient = new PredictionServiceClient({apiEndpoint});
    
    // Call the Gemma2 endpoint
    const gemma2Endpoint = `projects/${projectId}/locations/${endpointRegion}/endpoints/${endpointId}`;
    
    const [response] = await predictionServiceClient.predict({
      endpoint: gemma2Endpoint,
      instances,
    });
    
    const predictions = response.predictions;
    const text = predictions[0].stringValue;
    
    console.log('Predictions:', text);

    Java

    在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档

    如需向 Vertex AI 进行身份验证,请设置应用默认凭证。 如需了解详情,请参阅为本地开发环境设置身份验证

    
    import com.google.cloud.aiplatform.v1.EndpointName;
    import com.google.cloud.aiplatform.v1.PredictResponse;
    import com.google.cloud.aiplatform.v1.PredictionServiceClient;
    import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
    import com.google.gson.Gson;
    import com.google.protobuf.InvalidProtocolBufferException;
    import com.google.protobuf.Value;
    import com.google.protobuf.util.JsonFormat;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    public class Gemma2PredictTpu {
      private final PredictionServiceClient predictionServiceClient;
    
      // Constructor to inject the PredictionServiceClient
      public Gemma2PredictTpu(PredictionServiceClient predictionServiceClient) {
        this.predictionServiceClient = predictionServiceClient;
      }
    
      public static void main(String[] args) throws IOException {
        // TODO(developer): Replace these variables before running the sample.
        String projectId = "YOUR_PROJECT_ID";
        String endpointRegion = "us-west1";
        String endpointId = "YOUR_ENDPOINT_ID";
    
        PredictionServiceSettings predictionServiceSettings =
            PredictionServiceSettings.newBuilder()
                .setEndpoint(String.format("%s-aiplatform.googleapis.com:443", endpointRegion))
                .build();
        PredictionServiceClient predictionServiceClient =
            PredictionServiceClient.create(predictionServiceSettings);
        Gemma2PredictTpu creator = new Gemma2PredictTpu(predictionServiceClient);
    
        creator.gemma2PredictTpu(projectId, endpointRegion, endpointId);
      }
    
      // Demonstrates how to run inference on a Gemma2 model
      // deployed to a Vertex AI endpoint with TPU accelerators.
      public String gemma2PredictTpu(String projectId, String region,
               String endpointId) throws IOException {
        Map<String, Object> paramsMap = new HashMap<>();
        paramsMap.put("temperature", 0.9);
        paramsMap.put("maxOutputTokens", 1024);
        paramsMap.put("topP", 1.0);
        paramsMap.put("topK", 1);
        Value parameters = mapToValue(paramsMap);
        // Prompt used in the prediction
        String instance = "{ \"prompt\": \"Why is the sky blue?\"}";
        Value.Builder instanceValue = Value.newBuilder();
        JsonFormat.parser().merge(instance, instanceValue);
        // Encapsulate the prompt in a correct format for TPUs
        // Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
        List<Value> instances = new ArrayList<>();
        instances.add(instanceValue.build());
    
        EndpointName endpointName = EndpointName.of(projectId, region, endpointId);
    
        PredictResponse predictResponse = this.predictionServiceClient
            .predict(endpointName, instances, parameters);
        String textResponse = predictResponse.getPredictions(0).getStringValue();
        System.out.println(textResponse);
        return textResponse;
      }
    
      private static Value mapToValue(Map<String, Object> map) throws InvalidProtocolBufferException {
        Gson gson = new Gson();
        String json = gson.toJson(map);
        Value.Builder builder = Value.newBuilder();
        JsonFormat.parser().merge(json, builder);
        return builder.build();
      }
    }

    Go

    在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Go 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Go API 参考文档

    如需向 Vertex AI 进行身份验证,请设置应用默认凭证。 如需了解详情,请参阅为本地开发环境设置身份验证

    import (
    	"context"
    	"fmt"
    	"io"
    
    	"cloud.google.com/go/aiplatform/apiv1/aiplatformpb"
    
    	"google.golang.org/protobuf/types/known/structpb"
    )
    
    // predictTPU demonstrates how to run interference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
    func predictTPU(w io.Writer, client PredictionsClient, projectID, location, endpointID string) error {
    	ctx := context.Background()
    
    	// Note: client can be initialized in the following way:
    	// apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
    	// client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
    	// if err != nil {
    	// 	return fmt.Errorf("unable to create prediction client: %v", err)
    	// }
    	// defer client.Close()
    
    	gemma2Endpoint := fmt.Sprintf("projects/%s/locations/%s/endpoints/%s", projectID, location, endpointID)
    	prompt := "Why is the sky blue?"
    	parameters := map[string]interface{}{
    		"temperature":     0.9,
    		"maxOutputTokens": 1024,
    		"topP":            1.0,
    		"topK":            1,
    	}
    
    	// Encapsulate the prompt in a correct format for TPUs.
    	// Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
    	promptValue, err := structpb.NewValue(map[string]interface{}{
    		"prompt":     prompt,
    		"parameters": parameters,
    	})
    	if err != nil {
    		fmt.Fprintf(w, "unable to convert prompt to Value: %v", err)
    		return err
    	}
    
    	req := &aiplatformpb.PredictRequest{
    		Endpoint:  gemma2Endpoint,
    		Instances: []*structpb.Value{promptValue},
    	}
    
    	resp, err := client.Predict(ctx, req)
    	if err != nil {
    		return err
    	}
    
    	prediction := resp.GetPredictions()
    	value := prediction[0].GetStringValue()
    	fmt.Fprintf(w, "%v", value)
    
    	return nil
    }
    

    清理

    为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

    删除项目

    1. In the Google Cloud console, go to the Manage resources page.

      Go to Manage resources

    2. In the project list, select the project that you want to delete, and then click Delete.
    3. In the dialog, type the project ID, and then click Shut down to delete the project.

    删除各个资源

    如果您想保留项目,则删除本教程中使用的资源即可:

    • 取消部署模型并删除端点
    • 从 Model Registry 中删除模型

    取消部署模型并删除端点

    请使用以下方法之一取消部署模型并删除端点。

    控制台

    1. 在 Google Cloud 控制台中,依次点击在线预测端点

      转到“端点”页面

    2. 区域下拉列表中,选择您在其中部署了端点的区域。

    3. 点击端点名称以打开详情页面。例如:gemma2-2b-it-mg-one-click-deploy

    4. Gemma 2 (Version 1) 模型所对应的行中,点击 操作,然后点击从端点取消部署模型

    5. 从端点取消部署模型对话框中,点击取消部署

    6. 点击返回按钮,返回到端点页面。

      转到“端点”页面

    7. gemma2-2b-it-mg-one-click-deploy 行末尾,点击 操作,然后选择删除端点

    8. 在确认提示中,点击确认

    gcloud

    如需使用 Google Cloud CLI 取消部署模型并删除端点,请按以下步骤操作。

    在这些命令中,替换以下内容:

    • PROJECT_ID 替换为您的项目名称
    • LOCATION_ID 替换为您在其中部署了模型和端点的区域
    • ENDPOINT_ID 替换为端点 ID
    • DEPLOYED_MODEL_NAME 替换为模型的显示名称
    • DEPLOYED_MODEL_ID 替换为模型 ID
    1. 运行 gcloud ai endpoints list 命令获取端点 ID。此命令会列出项目中所有端点的 ID。记下本教程中使用的端点的 ID。

      gcloud ai endpoints list \
          --project=PROJECT_ID \
          --region=LOCATION_ID
      

      输出类似于以下内容。在输出中,该 ID 称为 ENDPOINT_ID

      Using endpoint [https://us-central1-aiplatform.googleapis.com/]
      ENDPOINT_ID: 1234567891234567891
      DISPLAY_NAME: gemma2-2b-it-mg-one-click-deploy
      
    2. 运行 gcloud ai models describe 命令获取模型 ID。记下本教程中部署的模型的 ID。

      gcloud ai models describe DEPLOYED_MODEL_NAME \
          --project=PROJECT_ID \
          --region=LOCATION_ID
      

      简略版输出如下所示。在输出中,该 ID 称为 deployedModelId

      Using endpoint [https://us-central1-aiplatform.googleapis.com/]
      artifactUri: [URI removed]
      baseModelSource:
        modelGardenSource:
          publicModelName: publishers/google/models/gemma2
      ...
      deployedModels:
      - deployedModelId: '1234567891234567891'
        endpoint: projects/12345678912/locations/us-central1/endpoints/12345678912345
      displayName: gemma2-2b-it-12345678912345
      etag: [ETag removed]
      modelSourceInfo:
        sourceType: MODEL_GARDEN
      name: projects/123456789123/locations/us-central1/models/gemma2-2b-it-12345678912345
      ...
      
    3. 从端点取消部署模型。您需要用到通过运行上述命令获得的端点 ID 和模型 ID。

      gcloud ai endpoints undeploy-model ENDPOINT_ID \
          --project=PROJECT_ID \
          --region=LOCATION_ID \
          --deployed-model-id=DEPLOYED_MODEL_ID
      

      此命令没有任何输出。

    4. 运行 gcloud ai endpoints delete 命令以删除端点。

      gcloud ai endpoints delete ENDPOINT_ID \
          --project=PROJECT_ID \
          --region=LOCATION_ID
      

      出现提示时,输入 y 以确认。此命令没有任何输出。

    删除模型

    控制台

    1. 在 Google Cloud 控制台中,从 Vertex AI 部分进入 Model Registry 页面。

      进入 Model Registry 页面

    2. 区域下拉列表中,选择您在其中部署了模型的区域。

    3. gemma2-2b-it-1234567891234 行末尾,点击 操作

    4. 选择删除模型

      删除模型时,所有关联的模型版本和评估都会从 Google Cloud 项目中一并删除。

    5. 在确认提示中,点击删除

    gcloud

    如需使用 Google Cloud CLI 删除模型,请向 gcloud ai models delete 命令提供模型的显示名称和区域。

    gcloud ai models delete DEPLOYED_MODEL_NAME \
        --project=PROJECT_ID \
        --region=LOCATION_ID
    

    DEPLOYED_MODEL_NAME 替换为模型的显示名称。将 PROJECT_ID 替换为您的项目名称。将 LOCATION_ID 替换为您在其中部署了模型的区域。

    后续步骤