通过 JetStream 使用 GKE 中的 TPU 应用 Gemma


本指南将介绍如何通过 JetStreamMaxText 使用 Google Kubernetes Engine (GKE) 上的张量处理单元 (TPU) 应用 Gemma 大语言模型 (LLM)。在本指南中,您将 Gemma 7B 参数指令调优模型权重下载到 Cloud Storage,并使用运行 JetStream 的容器将其部署到 GKE AutopilotStandard 集群上。

如果您在 JetStream 上部署模型时需要利用 Kubernetes 功能提供的可伸缩性、弹性和成本效益,那么本指南是一个很好的起点。

背景

您可以通过 JetStream 使用 GKE 中的 TPU 应用 Gemma,从而构建一个可直接用于生产环境的强大服务解决方案,该解决方案具备托管式 Kubernetes 的所有优势,包括经济高效的可伸缩性和更高的可用性。本部分介绍本教程中使用的关键技术。

Gemma

Gemma 是一组公开提供的轻量级生成式人工智能 (AI) 模型(根据开放许可发布)。这些 AI 模型可以在应用、硬件、移动设备或托管服务中运行。您可以使用 Gemma 模型生成文本,但也可以针对专门任务对这些模型进行调优。

如需了解详情,请参阅 Gemma 文档

TPU

TPU 是 Google 定制开发的应用专用集成电路 (ASIC),用于加速机器学习和使用 TensorFlowPyTorchJAX 等框架构建的 AI 模型。

使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:

  1. 了解 Cloud TPU 系统架构中的当前 TPU 版本可用性。
  2. 了解 GKE 中的 TPU

本教程介绍如何应用 Gemma 7B 模型。GKE 在单主机 TPUv5e 节点上部署模型,并根据模型要求配置 TPU 拓扑,以低延迟提供提示。

JetStream

JetStream 是由 Google 开发的开源推理服务框架。JetStream 可以在 TPU 和 GPU 上实现高性能、高吞吐量和内存优化的推理。它提供高级性能优化(包括连续批处理和量化技术),以协助 LLM 部署。JetStream 支持 PyTorch/XLA 和 JAX TPU 服务,从而实现最佳性能。

如需详细了解这些优化,请参阅 JetStream PyTorchJetStream MaxText 项目仓库。

MaxText

MaxText是一个高性能、可扩缩且适应性强的 JAX LLM 实现,基于如下开源 JAX 仓库构建:FlaxOrbaxOptax。MaxText 的仅解码器 LLM 实现是使用 Python 编写的。它大量利用 XLA 编译器来实现高性能,而无需构建自定义内核。

如需详细了解 MaxText 支持的最新模型和参数大小,请参阅 MaxtText 项目仓库

目标

本教程适用于使用 JAX 的生成式 AI 客户、GKE 的新用户或现有用户、机器学习工程师、MLOps (DevOps) 工程师或是对使用 Kubernetes 容器编排功能应用 LLM 感兴趣的平台管理员。

本教程介绍以下步骤:

  1. 根据模型特征准备一个具有推荐 TPU 拓扑的 GKE Autopilot 或 Standard 集群。
  2. 在 GKE 上部署 JetStream 组件。
  3. 获取并发布 Gemma 7B 指令调优模型。
  4. 应用已发布的模型并与之互动。

架构

本部分介绍本教程中使用的 GKE 架构。该架构包括一个 GKE Autopilot 或 Standard 集群,该集群用于预配 TPU 和托管 JetStream 组件以部署和应用模型。

下图展示了此架构的组件:

具有单主机 TPU 节点池(其中包含 Maxengine 和 Max HTTP 组件)的 GKE 集群的架构。

此架构包括以下组件:

  • GKE Autopilot 或 Standard 区域级集群。
  • 两个托管 JetStream 部署的单主机 TPU 切片节点池。
  • Service 组件将入站流量分布到所有 JetStream HTTP 副本。
  • JetStream HTTP 是一个 HTTP 服务器,它接受封装容器形式的 JetStream 所需格式的请求并将其发送到 JetStream 的 GRPC 客户端
  • Maxengine 是一个 JetStream 服务器,该服务器通过连续批处理执行推断。

准备工作

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role colunn to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      前往 IAM
    2. 选择项目。
    3. 点击 授予访问权限
    4. 新的主账号字段中,输入您的用户标识符。 这通常是 Google 账号的电子邮件地址。

    5. 选择角色列表中,选择一个角色。
    6. 如需授予其他角色,请点击 添加其他角色,然后添加其他各个角色。
    7. 点击保存
  • 确保您有足够的配额用于 8 个 TPU v5e PodSlice Lite 芯片。在本教程中,您将使用按需实例
  • 如果您还没有 Kaggle 账号,请创建一个。

获取对模型的访问权限

如需访问 Gemma 模型以部署到 GKE,您必须先签署许可同意协议。

您必须签署同意协议才能使用 Gemma。 请按照以下说明操作:

  1. 访问 Kaggle.com 上的 Gemma 模型同意页面
  2. 如果您尚未登录 Kaggle,请进行登录。
  3. 点击申请访问权限
  4. Choose Account for Consent(选择进行同意的账号)部分中,选择 Verify via Kaggle Account(通过 Kaggle 账号验证),以使用您的 Kaggle 账号进行同意。
  5. 接受模型条款及条件

生成一个访问令牌

如需通过 Kaggle 访问模型,您需要 Kaggle API 令牌。

如果您还没有令牌,请按照以下步骤生成新令牌:

  1. 在浏览器中,转到 Kaggle 设置
  2. API 部分下,点击 Create New Token(创建新令牌)。

系统将下载名为 kaggle.json 的文件。

准备环境

在本教程中,您将使用 Cloud Shell 来管理 Google Cloud 上托管的资源。Cloud Shell 预安装有本教程所需的软件,包括 kubectlgcloud CLI

如需使用 Cloud Shell 设置您的环境,请按照以下步骤操作:

  1. 在 Google Cloud 控制台中,点击 Google Cloud 控制台中的 Cloud Shell 激活图标 激活 Cloud Shell 以启动 Cloud Shell 会话。此操作会在 Google Cloud 控制台的底部窗格中启动会话。

  2. 设置默认环境变量:

    gcloud config set project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export BUCKET_NAME=BUCKET_NAME
    export REGION=REGION
    export LOCATION=LOCATION
    export CLUSTER_VERSION=CLUSTER_VERSION
    

    替换以下值:

    • PROJECT_ID:您的 Google Cloud 项目 ID
    • CLUSTER_NAME:GKE 集群的名称。
    • BUCKET_NAME:Cloud Storage 存储桶的名称。您无需指定 gs:// 前缀。
    • REGION:GKE 集群、Cloud Storage 存储桶和 TPU 节点所在的区域。该区域包含可以使用 TPU v5e 机器类型的可用区(例如 us-west1us-west4us-central1us-east1us-east5europe-west4)。对于 Autopilot 集群,请确保您有足够的 TPU v5e 可用区级资源用于所选的区域。
    • (仅限标准集群)LOCATION可以使用 TPU 资源的可用区(例如 us-west4-a)。对于 Autopilot 集群,您无需指定可用区,只需指定区域。
    • CLUSTER_VERSION:GKE 版本,必须支持您要使用的机器类型。 请注意,默认 GKE 版本可能无法为您的目标 TPU 提供可用性。 如需查看 TPU 机器类型可用的最低 GKE 版本列表,请参阅 GKE 中的 TPU 可用性

创建和配置 Google Cloud 资源

请按照以下说明创建所需的资源。

创建 GKE 集群

您可以在 GKE Autopilot 或 Standard 集群中的 TPU 上应用 Gemma。我们建议您使用 Autopilot 集群获得全托管式 Kubernetes 体验。如需选择最适合您的工作负载的 GKE 操作模式,请参阅选择 GKE 操作模式

Autopilot

在 Cloud Shell 中,运行以下命令:

gcloud container clusters create-auto ${CLUSTER_NAME} \
  --project=${PROJECT_ID} \
  --region=${REGION} \
  --cluster-version=${CLUSTER_VERSION}

标准

  1. 创建使用适用于 GKE 的工作负载身份联合的区域级 GKE Standard 集群。

    gcloud container clusters create ${CLUSTER_NAME} \
        --enable-ip-alias \
        --machine-type=e2-standard-4 \
        --num-nodes=2 \
        --cluster-version=${CLUSTER_VERSION} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --location=${REGION}
    

    集群创建可能需要几分钟的时间。

  2. 运行以下命令来为集群创建节点池

    gcloud container node-pools create gemma-7b-tpu-nodepool \
      --cluster=${CLUSTER_NAME} \
      --machine-type=ct5lp-hightpu-8t \
      --project=${PROJECT_ID} \
      --num-nodes=2 \
      --region=${REGION} \
      --node-locations=${LOCATION}
    

    GKE 会创建一个具有 2x4 拓扑和两个节点的 TPU v5e 节点池。

创建 Cloud Storage 存储桶

在 Cloud Shell 中,运行以下命令:

gcloud storage buckets create gs://${BUCKET_NAME} --location=${REGION}

这会创建一个 Cloud Storage 存储桶来存储您从 Kaggle 下载的模型文件。

将访问令牌上传到 Cloud Shell

在 Cloud Shell 中,您可以将 Kaggle API 令牌上传到 Google Cloud 项目:

  1. 在 Cloud Shell 中,点击 更多 > 上传
  2. 选择“文件”,然后点击选择文件
  3. 打开 kaggle.json 文件。
  4. 点击上传

为 Kaggle 凭据创建 Kubernetes Secret

在 Cloud Shell 中,执行以下操作:

  1. 配置 kubectl 以与您的集群通信:

    gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION}
    
  2. 创建一个 Secret 以存储 Kaggle 凭据:

    kubectl create secret generic kaggle-secret \
        --from-file=kaggle.json
    

使用适用于 GKE 的工作负载身份联合配置工作负载访问权限

为应用分配 Kubernetes ServiceAccount,并将该 Kubernetes ServiceAccount 配置为充当 IAM 服务账号。

  1. 为您的应用创建 IAM 服务账号:

    gcloud iam service-accounts create wi-jetstream
    
  2. 为您的 IAM 服务账号添加 IAM 政策绑定以管理 Cloud Storage:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
        --member "serviceAccount:wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com" \
        --role roles/storage.objectUser
    
    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
        --member "serviceAccount:wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com" \
        --role roles/storage.insightsCollectorService
    
  3. 通过在两个服务账号之间添加 IAM 政策绑定,允许 Kubernetes ServiceAccount 模拟 IAM 服务账号。此绑定允许 Kubernertes ServiceAccount 充当 IAM 服务账号:

    gcloud iam service-accounts add-iam-policy-binding wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com \
        --role roles/iam.workloadIdentityUser \
        --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/default]"
    
  4. 使用 IAM 服务账号的电子邮件地址为 Kubernetes 服务账号添加注解

    kubectl annotate serviceaccount default \
        iam.gke.io/gcp-service-account=wi-jetstream@${PROJECT_ID}.iam.gserviceaccount.com
    

转换模型检查点

在本部分中,您将创建一个 Job 来执行以下操作:

  1. 从 Kaggle 下载基础 Orbax 检查点。
  2. 将检查点上传到 Cloud Storage 存储桶。
  3. 将检查点转换为与 MaxText 兼容的检查点。
  4. 取消扫描要用于传送的检查点。

部署模型检查点转换 Job

请按照以下说明下载并转换 Gemma 7B 模型检查点文件。

  1. 创建以下清单作为 job-7b.yaml

    apiVersion: batch/v1
    kind: Job
    metadata:
      name: data-loader-7b
    spec:
      ttlSecondsAfterFinished: 30
      template:
        spec:
          restartPolicy: Never
          containers:
          - name: inference-checkpoint
            image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.2
            args:
            - -b=BUCKET_NAME
            - -m=google/gemma/maxtext/7b-it/2
            volumeMounts:
            - mountPath: "/kaggle/"
              name: kaggle-credentials
              readOnly: true
            resources:
              requests:
                google.com/tpu: 8
              limits:
                google.com/tpu: 8
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 2x4
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
          volumes:
          - name: kaggle-credentials
            secret:
              defaultMode: 0400
              secretName: kaggle-secret
    
  2. 应用清单:

    kubectl apply -f job-7b.yaml
    
  3. 等待安排 Job 的 Pod 开始运行:

    kubectl get pod -w
    

    输出将如下所示,此过程可能需要几分钟时间:

    NAME                  READY   STATUS              RESTARTS   AGE
    data-loader-7b-abcd   0/1     ContainerCreating   0          28s
    data-loader-7b-abcd   1/1     Running             0          51s
    

    对于 Autopilot 集群,预配所需的 TPU 资源可能需要几分钟时间。

  4. 查看来自 Job 的日志:

    kubectl logs -f jobs/data-loader-7b
    

    Job 完成后,输出类似于以下内容:

    Successfully generated decode checkpoint at: gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
    + echo -e '\nCompleted unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items'
    
    Completed unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
    

部署 JetStream

在本部分中,您将部署 JetStream 容器来应用 Gemma 模型。

请按照以下说明部署 Gemma 7B 指令调优模型。

  1. 创建以下 jetstream-gemma-deployment.yaml 清单:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: maxengine-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: maxengine-server
      template:
        metadata:
          labels:
            app: maxengine-server
        spec:
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 2x4
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
          containers:
          - name: maxengine-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.2
            args:
            - model_name=gemma-7b
            - tokenizer_path=assets/tokenizer.gemma
            - per_device_batch_size=4
            - max_prefill_predict_length=1024
            - max_target_length=2048
            - async_checkpointing=false
            - ici_fsdp_parallelism=1
            - ici_autoregressive_parallelism=-1
            - ici_tensor_parallelism=1
            - scan_layers=false
            - weight_dtype=bfloat16
            - load_parameters_path=gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
            - prometheus_port=PROMETHEUS_PORT
            ports:
            - containerPort: 9000
            resources:
              requests:
                google.com/tpu: 8
              limits:
                google.com/tpu: 8
          - name: jetstream-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2
            ports:
            - containerPort: 8000
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: maxengine-server
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
      - protocol: TCP
        name: jetstream-grpc
        port: 9000
        targetPort: 9000
    

    该清单设置以下关键属性:

    • tokenizer_path:模型词元化器的路径。
    • load_parameters_path:Cloud Storage 存储桶中存储检查点的路径。
    • per_device_batch_size:每个设备的解码批次大小,其中一个 TPU 芯片等于一个设备。
    • max_prefill_predict_length:进行自动回归时预填充的最大长度。
    • max_target_length:序列长度上限。
    • model_name:模型名称 (gemma-7b)。
    • ici_fsdp_parallelism:用于完全分片数据并行 (FSDP) 的分片数。
    • ici_tensor_parallelism:用于张量并行的分片数。
    • ici_autoregressive_parallelism:用于自动回归并行的分片数。
    • prometheus_port:用于公开 Prometheus 指标的端口。如果不需要指标,请移除此参数。
    • scan_layers: 扫描层布尔值标志 (boolean)。
    • weight_dtype:权重数据类型 (bfloat16)。
  2. 应用清单:

    kubectl apply -f jetstream-gemma-deployment.yaml
    
  3. 验证 Deployment:

    kubectl get deployment
    

    输出类似于以下内容:

    NAME                              READY   UP-TO-DATE   AVAILABLE   AGE
    maxengine-server                  2/2     2            2           ##s
    

    对于 Autopilot 集群,预配所需的 TPU 资源可能需要几分钟时间。

  4. 查看 HTTP 服务器日志以检查模型是否已加载和编译。服务器可能需要几分钟才能完成此操作。

    kubectl logs deploy/maxengine-server -f -c jetstream-http
    

    输出类似于以下内容:

    kubectl logs deploy/maxengine-server -f -c jetstream-http
    
    INFO:     Started server process [1]
    INFO:     Waiting for application startup.
    INFO:     Application startup complete.
    INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
    
  5. 查看 MaxEngine 日志并验证编译是否已完成。

    kubectl logs deploy/maxengine-server -f -c maxengine-server
    

    输出类似于以下内容:

    2024-03-29 17:09:08,047 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(initialize) in 0.26236414909362793 sec
    2024-03-29 17:09:08,150 - root - INFO - ---------Generate params 0 loaded.---------
    

应用模型

在本部分中,您将与模型互动。

设置端口转发

您可以通过在上一步中创建的 ClusterIP Service 访问 JetStream Deployment。只能从集群内部访问 ClusterIP Service。因此,如需从集群外部访问该 Service,请完成以下步骤:

如需建立端口转发会话,请运行以下命令:

kubectl port-forward svc/jetstream-svc 8000:8000

使用 curl 与模型互动

  1. 通过打开新终端并运行以下命令,验证您是否可以访问 JetStream HTTP 服务器:

    curl --request POST \
    --header "Content-type: application/json" \
    -s \
    localhost:8000/generate \
    --data \
    '{
        "prompt": "What are the top 5 programming languages",
        "max_tokens": 200
    }'
    

    由于模型预热,初始请求可能需要几秒钟才能完成。 输出类似于以下内容:

    {
        "response": "\nfor data science in 2023?\n\n**1. Python:**\n- Widely used for data science due to its simplicity, readability, and extensive libraries for data wrangling, analysis, visualization, and machine learning.\n- Popular libraries include pandas, scikit-learn, and matplotlib.\n\n**2. R:**\n- Statistical programming language widely used for data analysis, visualization, and modeling.\n- Popular libraries include ggplot2, dplyr, and caret.\n\n**3. Java:**\n- Enterprise-grade language with strong performance and scalability.\n- Popular libraries include Spark, TensorFlow, and Weka.\n\n**4. C++:**\n- High-performance language often used for data analytics and machine learning models.\n- Popular libraries include TensorFlow, PyTorch, and OpenCV.\n\n**5. SQL:**\n- Relational database language essential for data wrangling and querying large datasets.\n- Popular tools"
    }
    

(可选)通过 Gradio 聊天界面与模型互动

在本部分中,您将构建一个网页聊天应用,可让您与指令调优模型互动。

Gradio 是一个 Python 库,它具有一个可为聊天机器人创建界面的 ChatInterface 封装容器。

部署聊天界面

  1. 在 Cloud Shell 中,将以下清单保存为 gradio.yaml

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: gradio
      labels:
        app: gradio
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: gradio
      template:
        metadata:
          labels:
            app: gradio
        spec:
          containers:
          - name: gradio
            image: us-docker.pkg.dev/google-samples/containers/gke/gradio-app:v1.0.3
            resources:
              requests:
                cpu: "512m"
                memory: "512Mi"
              limits:
                cpu: "1"
                memory: "512Mi"
            env:
            - name: CONTEXT_PATH
              value: "/generate"
            - name: HOST
              value: "http://jetstream-http-svc:8000"
            - name: LLM_ENGINE
              value: "max"
            - name: MODEL_ID
              value: "gemma"
            - name: USER_PROMPT
              value: "<start_of_turn>user\nprompt<end_of_turn>\n"
            - name: SYSTEM_PROMPT
              value: "<start_of_turn>model\nprompt<end_of_turn>\n"
            ports:
            - containerPort: 7860
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: gradio
    spec:
      selector:
        app: gradio
      ports:
        - protocol: TCP
          port: 8080
          targetPort: 7860
      type: ClusterIP
    
  2. 应用清单:

    kubectl apply -f gradio.yaml
    
  3. 等待部署成为可用状态:

    kubectl wait --for=condition=Available --timeout=300s deployment/gradio
    

使用聊天界面

  1. 在 Cloud Shell 中,运行以下命令:

    kubectl port-forward service/gradio 8080:8080
    

    这会创建从 Cloud Shell 到 Gradio 服务的端口转发。

  2. 点击 Cloud Shell 任务栏右上角的 “网页预览”图标 网页预览按钮。点击在端口 8080 上预览。浏览器中会打开一个新的标签页。

  3. 使用 Gradio 聊天界面与 Gemma 互动。添加提示,然后点击提交

问题排查

  • 如果您收到 Empty reply from server 消息,则容器可能尚未完成模型数据下载。再次检查 Pod 的日志中是否包含 Connected 消息,该消息表明模型已准备好进行应用。
  • 如果您看到 Connection refused,请验证您的端口转发已启用

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除已部署的资源

为避免系统因您在本指南中创建的资源而向您的 Google Cloud 账号收取费用,请运行以下命令并按照提示进行操作:

gcloud container clusters delete ${CLUSTER_NAME} --region=${REGION}

gcloud iam service-accounts delete wi-jetstream@PROJECT_ID.iam.gserviceaccount.com

gcloud storage rm --recursive gs://BUCKET_NAME

后续步骤