通过 JetStream 和 PyTorch 使用 GKE 上的 TPU 应用 LLM


本指南介绍了如何通过 JetStreamPyTorch 使用 Google Kubernetes Engine (GKE) 上的张量处理单元 (TPU) 应用大语言模型 (LLM)。在本指南中,您将模型权重下载到 Cloud Storage,然后使用运行 JetStream 的容器将其部署到 GKE AutopilotStandard 集群上。

如果您在 JetStream 上部署模型时需要利用 Kubernetes 功能提供的可伸缩性、弹性和成本效益,那么本指南是一个很好的起点。

本指南适用于使用 PyTorch 的生成式 AI 客户、GKE 的新用户或现有用户、机器学习工程师、MLOps (DevOps) 工程师或者对使用 Kubernetes 容器编排功能应用 LLM 感兴趣的平台管理员。

背景

通过 JetStream 使用 GKE 上的 TPU 应用 LLM,您可以构建一个可用于生产用途的强大服务解决方案,具备托管式 Kubernetes 的所有优势,包括经济高效、可伸缩性和更高的可用性。本部分介绍本教程中使用的关键技术。

TPU 简介

TPU 是 Google 定制开发的应用专用集成电路 (ASIC),用于加速机器学习和使用 TensorFlowPyTorchJAX 等框架构建的 AI 模型。

使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:

  1. 了解 Cloud TPU 系统架构中的当前 TPU 版本可用性。
  2. 了解 GKE 中的 TPU

本教程介绍如何应用各种 LLM 模型。GKE 在单主机 TPUv5e 节点上部署模型,并根据模型要求配置 TPU 拓扑,以低延迟提供提示。

JetStream 简介

JetStream 是由 Google 开发的开源推理服务框架。JetStream 可以在 TPU 和 GPU 上实现高性能、高吞吐量和内存优化的推理。JetStream 提供高级性能优化(包括连续批处理、KV 缓存优化和量化技术),以协助 LLM 部署。JetStream 支持 PyTorch/XLA 和 JAX TPU 服务,从而实现最佳性能。

连续批处理

连续批处理是一种将传入的推理请求动态分组为多个批次的技术,从而缩短延迟时间并提高吞吐量。

KV 缓存量化

KV 缓存量化涉及压缩注意力机制中使用的键值对缓存,从而降低内存需求。

Int8 权重量化

Int8 权重量化可将模型权重的精确率从 32 位浮点数降低到 8 位整数,从而加快计算速度并减少内存用量。

如需详细了解这些优化,请参阅 JetStream PyTorchJetStream MaxText 项目仓库。

PyTorch 简介

PyTorch 是由 Meta 开发的开源机器学习框架,现在是 Linux Foundation 综合框架的一部分。PyTorch 提供张量计算和深度神经网络等高级功能。

目标

  1. 根据模型特征准备一个具有推荐 TPU 拓扑的 GKE Autopilot 或 Standard 集群。
  2. 在 GKE 上部署 JetStream 组件。
  3. 获取并发布模型。
  4. 应用已发布的模型并与之互动。

架构

本部分介绍本教程中使用的 GKE 架构。该架构包括一个 GKE Autopilot 或 Standard 集群,该集群用于预配 TPU 和托管 JetStream 组件以部署和应用模型。

下图展示了此架构的组件:

具有单主机 TPU 节点池(其中包含 JetStream-PyTorch 和 JetStream HTTP 组件)的 GKE 集群的架构。

此架构包括以下组件:

  • GKE Autopilot 或 Standard 区域级集群。
  • 两个托管 JetStream 部署的单主机 TPU 切片节点池。
  • Service 组件将入站流量分布到所有 JetStream HTTP 副本。
  • JetStream HTTP 是一个 HTTP 服务器,它接受封装容器形式的 JetStream 所需格式的请求并将其发送到 JetStream 的 GRPC 客户端
  • JetStream-PyTorch 是一个 JetStream 服务器,该服务器通过连续批处理执行推断。

准备工作

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role colunn to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      进入 IAM
    2. 选择项目。
    3. 点击 授予访问权限
    4. 新的主账号字段中,输入您的用户标识符。 这通常是 Google 账号的电子邮件地址。

    5. 选择角色列表中,选择一个角色。
    6. 如需授予其他角色,请点击 添加其他角色,然后添加其他各个角色。
    7. 点击 Save(保存)。
  • 确保您有足够的配额用于 8 个 TPU v5e PodSlice Lite 芯片。在本教程中,您将使用按需实例
  • 如果您还没有 Hugging Face 令牌,请创建一个。

获取对模型的访问权限

获取对 Hugging Face 上的各种模型的访问权限以部署到 GKE

Gemma 7B-it

如需访问 Gemma 模型以部署到 GKE,您必须先签署许可同意协议。

  1. 访问 Hugging Face 上的 Gemma 模型同意页面
  2. 如果您尚未登录 Hugging Face,请进行登录。
  3. 查看并接受模型条款及条件

Llama 3 8B

如需获取对 Llama 3 模型的访问权限以部署到 GKE,您必须先签署许可同意协议。

  1. 访问 Hugging Face 上的 Llama 3 模型同意页面
  2. 如果您尚未登录 Hugging Face,请进行登录。
  3. 查看并接受模型条款及条件

准备环境

在本教程中,您将使用 Cloud Shell 来管理 Google Cloud 上托管的资源。Cloud Shell 预安装有本教程所需的软件,包括 kubectlgcloud CLI

如需使用 Cloud Shell 设置您的环境,请按照以下步骤操作:

  1. 在 Google Cloud 控制台中,点击 Google Cloud 控制台中的 Cloud Shell 激活图标 激活 Cloud Shell 以启动 Cloud Shell 会话。此操作会在 Google Cloud 控制台的底部窗格中启动会话。

  2. 设置默认环境变量:

    gcloud config set project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export BUCKET_NAME=BUCKET_NAME
    export REGION=REGION
    export LOCATION=LOCATION
    export CLUSTER_VERSION=CLUSTER_VERSION
    

    替换以下值:

    • PROJECT_ID:您的 Google Cloud 项目 ID
    • CLUSTER_NAME:GKE 集群的名称。
    • BUCKET_NAME:Cloud Storage 存储桶的名称。您无需指定 gs:// 前缀。
    • REGION:GKE 集群、Cloud Storage 存储桶和 TPU 节点所在的区域。该区域包含可以使用 TPU v5e 机器类型的可用区(例如 us-west1us-west4us-central1us-east1us-east5europe-west4)。对于 Autopilot 集群,请确保您有足够的 TPU v5e 可用区级资源用于所选的区域。
    • (仅限标准集群)LOCATION可以使用 TPU 资源的可用区(例如 us-west4-a)。对于 Autopilot 集群,您无需指定可用区,只需指定区域。
    • CLUSTER_VERSION:GKE 版本,必须支持您要使用的机器类型。请注意,默认 GKE 版本可能不适用于目标 TPU。如需查看按 TPU 机器类型提供的最低 GKE 版本列表,请参阅 GKE 中的 TPU 可用性

创建和配置 Google Cloud 资源

请按照以下说明创建所需的资源。

创建 GKE 集群

您可以在 GKE Autopilot 或 Standard 集群中的 TPU 上应用 Gemma。我们建议您使用 Autopilot 集群获得全托管式 Kubernetes 体验。如需选择最适合您的工作负载的 GKE 操作模式,请参阅选择 GKE 操作模式

Autopilot

创建 Autopilot GKE 集群:

gcloud container clusters create-auto CLUSTER_NAME \
    --project=PROJECT_ID \
    --region=REGION \
    --cluster-version=CLUSTER_VERSION

标准

  1. 创建使用适用于 GKE 的工作负载身份联合的区域级 GKE Standard 集群:

    gcloud container clusters create CLUSTER_NAME \
        --enable-ip-alias \
        --machine-type=e2-standard-4 \
        --num-nodes=2 \
        --cluster-version=CLUSTER_VERSION \
        --workload-pool=PROJECT_ID.svc.id.goog \
        --location=REGION
    

    集群创建可能需要几分钟的时间。

  2. 创建具有 2x4 拓扑和两个节点的 TPU v5e 节点池

    gcloud container node-pools create tpu-nodepool \
      --cluster=CLUSTER_NAME \
      --machine-type=ct5lp-hightpu-8t \
      --project=PROJECT_ID \
      --num-nodes=2 \
      --region=REGION \
      --node-locations=LOCATION
    

创建 Cloud Storage 存储桶

创建一个 Cloud Storage 存储桶来存储转换后的检查点:

gcloud storage buckets create gs://BUCKET_NAME --location=REGION

在 Cloud Shell 中生成 Hugging Face CLI 令牌

如果您还没有 Hugging Face 令牌,请生成一个新令牌:

  1. 点击您的个人资料 > 设置 > 访问令牌
  2. 点击新建令牌
  3. 指定您选择的名称和一个至少为 Read 的角色。
  4. 点击 Generate a token(生成令牌)。
  5. 修改对访问令牌的权限,以拥有对模型的 Hugging Face 仓库的读取权限。
  6. 将生成的令牌复制到剪贴板。

为 Hugging Face 凭据创建 Kubernetes Secret

在 Cloud Shell 中,执行以下操作:

  1. 配置 kubectl 以与您的集群通信:

    gcloud container clusters get-credentials CLUSTER_NAME --location=REGION
    
  2. 创建一个 Secret 来存储 Hugging Face 凭据:

    kubectl create secret generic huggingface-secret \
        --from-literal=HUGGINGFACE_TOKEN=HUGGINGFACE_TOKEN
    

    HUGGINGFACE_TOKEN 替换为您的 Hugging Face 令牌。

使用适用于 GKE 的工作负载身份联合配置工作负载访问权限

为应用分配 Kubernetes ServiceAccount,并将该 Kubernetes ServiceAccount 配置为充当 IAM 服务账号。

  1. 为您的应用创建 IAM 服务账号:

    gcloud iam service-accounts create wi-jetstream
    
  2. 为您的 IAM 服务账号添加 IAM 政策绑定以管理 Cloud Storage:

    gcloud projects add-iam-policy-binding PROJECT_ID \
        --member "serviceAccount:wi-jetstream@PROJECT_ID.iam.gserviceaccount.com" \
        --role roles/storage.objectUser
    
    gcloud projects add-iam-policy-binding PROJECT_ID \
        --member "serviceAccount:wi-jetstream@PROJECT_ID.iam.gserviceaccount.com" \
        --role roles/storage.insightsCollectorService
    
  3. 通过在两个服务账号之间添加 IAM 政策绑定,允许 Kubernetes ServiceAccount 模拟 IAM 服务账号。此绑定允许 Kubernertes ServiceAccount 充当 IAM 服务账号:

    gcloud iam service-accounts add-iam-policy-binding wi-jetstream@PROJECT_ID.iam.gserviceaccount.com \
        --role roles/iam.workloadIdentityUser \
        --member "serviceAccount:PROJECT_ID.svc.id.goog[default/default]"
    
  4. 使用 IAM 服务账号的电子邮件地址为 Kubernetes 服务账号添加注解

    kubectl annotate serviceaccount default \
        iam.gke.io/gcp-service-account=wi-jetstream@PROJECT_ID.iam.gserviceaccount.com
    

转换模型检查点

在本部分中,您将创建一个 Job 来执行以下操作:

  1. 将基础检查点从 Hugging Face 下载到本地目录。
  2. 将检查点转换为与 JetStream-Pytorch 兼容的检查点。
  3. 将检查点上传到 Cloud Storage 存储桶。

部署模型检查点转换 Job

Gemma 7B-it

下载并转换 Gemma 7B 模型检查点文件:

  1. 将以下清单保存为 job-checkpoint-converter.yaml

    apiVersion: batch/v1
    kind: Job
    metadata:
      name: checkpoint-converter
    spec:
      ttlSecondsAfterFinished: 30
      template:
        spec:
          restartPolicy: Never
          containers:
          - name: inference-checkpoint
            image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.3
            args:
            - -s=jetstream-pytorch
            - -m=google/gemma-7b-it-pytorch
            - -o=gs://BUCKET_NAME/pytorch/gemma-7b-it/final/bf16/
            - -n=gemma
            - -q=False
            - -h=True
            volumeMounts:
            - mountPath: "/huggingface/"
              name: huggingface-credentials
              readOnly: true
            resources:
              requests:
                google.com/tpu: 8
              limits:
                google.com/tpu: 8
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 2x4
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
          volumes:
          - name: huggingface-credentials
            secret:
              defaultMode: 0400
              secretName: huggingface-secret

Llama 3 8B

下载并转换 Llama 3 8B 模型检查点文件:

  1. 将以下清单保存为 job-checkpoint-converter.yaml

    apiVersion: batch/v1
    kind: Job
    metadata:
      name: checkpoint-converter
    spec:
      ttlSecondsAfterFinished: 30
      template:
        spec:
          restartPolicy: Never
          containers:
          - name: inference-checkpoint
            image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.3
            args:
            - -s=jetstream-pytorch
            - -m=meta-llama/Meta-Llama-3-8B
            - -o=gs://BUCKET_NAME/pytorch/llama-3-8b/final/bf16/
            - -n=llama-3
            - -q=False
            - -h=True
            volumeMounts:
            - mountPath: "/huggingface/"
              name: huggingface-credentials
              readOnly: true
            resources:
              requests:
                google.com/tpu: 8
              limits:
                google.com/tpu: 8
          nodeSelector:
            cloud.google.com/gke-tpu-topology: 2x4
            cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
          volumes:
          - name: huggingface-credentials
            secret:
              defaultMode: 0400
              secretName: huggingface-secret
  1. BUCKET_NAME 替换为您之前创建的 GSBucket:

    sed -i "s|BUCKET_NAME|BUCKET_NAME|g" job-checkpoint-converter.yaml
    
  2. 应用清单:

    kubectl apply -f job-checkpoint-converter.yaml
    
  3. 等待安排 Job 的 Pod 开始运行:

    kubectl get pod -w
    

    输出将如下所示,此过程可能需要几分钟时间:

    NAME                        READY   STATUS              RESTARTS   AGE
    checkpoint-converter-abcd   0/1     ContainerCreating   0          28s
    checkpoint-converter-abcd   1/1     Running             0          51s
    

    对于 Autopilot 集群,预配所需的 TPU 资源可能需要几分钟时间。

  4. 通过查看 Job 的日志来验证 Job 是否已完成:

    kubectl logs -f jobs/checkpoint-converter
    

    Job 完成后,输出类似于以下内容:

    Completed uploading converted checkpoint from local path /pt-ckpt/ to GSBucket gs://BUCKET_NAME/pytorch/<model_name>/final/bf16/"
    

部署 JetStream

部署 JetStream 容器以应用模型:

将以下清单保存为 jetstream-pytorch-deployment.yaml

Gemma 7B-it

apiVersion: apps/v1
kind: Deployment
metadata:
  name: jetstream-pytorch-server
spec:
  replicas: 2
  selector:
    matchLabels:
      app: jetstream-pytorch-server
  template:
    metadata:
      labels:
        app: jetstream-pytorch-server
      annotations:
        gke-gcsfuse/volumes: "true"
    spec:
      nodeSelector:
        cloud.google.com/gke-tpu-topology: 2x4
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      containers:
      - name: jetstream-pytorch-server
        image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pytorch-server:v0.2.3
        args:
        - --size=7b
        - --model_name=gemma
        - --batch_size=32
        - --max_cache_length=2048
        - --quantize_weights=False
        - --quantize_kv_cache=False
        - --tokenizer_path=/models/pytorch/gemma-7b-it/final/bf16/tokenizer.model
        - --checkpoint_path=/models/pytorch/gemma-7b-it/final/bf16/model.safetensors
        ports:
        - containerPort: 9000
        volumeMounts:
        - name: gcs-fuse-checkpoint
          mountPath: /models
          readOnly: true
        resources:
          requests:
            google.com/tpu: 8
          limits:
            google.com/tpu: 8
      - name: jetstream-http
        image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2
        ports:
        - containerPort: 8000
      volumes:
      - name: gcs-fuse-checkpoint
        csi:
          driver: gcsfuse.csi.storage.gke.io
          readOnly: true
          volumeAttributes:
            bucketName: BUCKET_NAME
            mountOptions: "implicit-dirs"
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: jetstream-pytorch-server
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000

Llama 3 8B

apiVersion: apps/v1
kind: Deployment
metadata:
  name: jetstream-pytorch-server
spec:
  replicas: 2
  selector:
    matchLabels:
      app: jetstream-pytorch-server
  template:
    metadata:
      labels:
        app: jetstream-pytorch-server
      annotations:
        gke-gcsfuse/volumes: "true"
    spec:
      nodeSelector:
        cloud.google.com/gke-tpu-topology: 2x4
        cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
      containers:
      - name: jetstream-pytorch-server
        image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pytorch-server:v0.2.3
        args:
        - --size=8b
        - --model_name=llama-3
        - --batch_size=32
        - --max_cache_length=2048
        - --quantize_weights=False
        - --quantize_kv_cache=False
        - --tokenizer_path=/models/pytorch/llama-3-8b/final/bf16/tokenizer.model
        - --checkpoint_path=/models/pytorch/llama-3-8b/final/bf16/model.safetensors
        ports:
        - containerPort: 9000
        volumeMounts:
        - name: gcs-fuse-checkpoint
          mountPath: /models
          readOnly: true
        resources:
          requests:
            google.com/tpu: 8
          limits:
            google.com/tpu: 8
      - name: jetstream-http
        image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2
        ports:
        - containerPort: 8000
      volumes:
      - name: gcs-fuse-checkpoint
        csi:
          driver: gcsfuse.csi.storage.gke.io
          readOnly: true
          volumeAttributes:
            bucketName: BUCKET_NAME
            mountOptions: "implicit-dirs"
---
apiVersion: v1
kind: Service
metadata:
  name: jetstream-svc
spec:
  selector:
    app: jetstream-pytorch-server
  ports:
  - protocol: TCP
    name: jetstream-http
    port: 8000
    targetPort: 8000

该清单设置以下关键属性:

  • size:您的模型的大小。
  • model_name:模型名称(gemmallama-3)。
  • batch_size:每个设备的解码批次大小,其中一个 TPU 芯片等于一个设备。
  • max_cache_length:kv 缓存的长度上限。
  • quantize_weights:检查点是否已量化。
  • quantize_kv_cache:kv 缓存是否已量化。
  • tokenizer_path:模型词元化器文件的路径。
  • checkpoint_path:检查点的路径。
  1. BUCKET_NAME 替换为您之前创建的 GSBucket:

    sed -i "s|BUCKET_NAME|BUCKET_NAME|g" jetstream-pytorch-deployment.yaml
    
  2. 应用清单:

    kubectl apply -f jetstream-pytorch-deployment.yaml
    
  3. 验证 Deployment:

    kubectl get deployment
    

    输出类似于以下内容:

    NAME                              READY   UP-TO-DATE   AVAILABLE   AGE
    jetstream-pytorch-server          2/2     2            2           ##s
    

    对于 Autopilot 集群,预配所需的 TPU 资源可能需要几分钟时间。

  4. 查看 HTTP 服务器日志以检查模型是否已加载和编译。服务器可能需要几分钟才能完成此操作。

    kubectl logs deploy/jetstream-pytorch-server -f -c jetstream-http
    

    输出类似于以下内容:

    kubectl logs deploy/jetstream-pytorch-server -f -c jetstream-http
    
    INFO:     Started server process [1]
    INFO:     Waiting for application startup.
    INFO:     Application startup complete.
    INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
    
  5. 查看 JetStream-PyTorch 服务器日志并验证编译是否已完成:

    kubectl logs deploy/jetstream-pytorch-server -f -c jetstream-pytorch-server
    

    输出类似于以下内容:

    Started jetstream_server....
    2024-04-12 04:33:37,128 - root - INFO - ---------Generate params 0 loaded.---------
    

应用模型

在本部分中,您将与模型互动。

设置端口转发

您可以通过在上一步中创建的 ClusterIP Service 访问 JetStream Deployment。只能从集群内部访问 ClusterIP Service。因此,如需从集群外部访问该 Service,请完成以下步骤:

如需建立端口转发会话,请运行以下命令:

kubectl port-forward svc/jetstream-svc 8000:8000

使用 curl 与模型互动

  1. 通过打开新终端并运行以下命令,验证您是否可以访问 JetStream HTTP 服务器:

    curl --request POST \
    --header "Content-type: application/json" \
    -s \
    localhost:8000/generate \
    --data \
    '{
        "prompt": "What are the top 5 programming languages",
        "max_tokens": 200
    }'
    

    由于模型预热,初始请求可能需要几秒钟才能完成。 输出类似于以下内容:

    {
        "response": " for data science in 2023?\n\n**1. Python:**\n- Widely used for data science due to its readability, extensive libraries (pandas, scikit-learn), and integration with other tools.\n- High demand for Python programmers in data science roles.\n\n**2. R:**\n- Popular choice for data analysis and visualization, particularly in academia and research.\n- Extensive libraries for statistical modeling and data wrangling.\n\n**3. Java:**\n- Enterprise-grade platform for data science, with strong performance and scalability.\n- Widely used in data mining and big data analytics.\n\n**4. SQL:**\n- Essential for data querying and manipulation, especially in relational databases.\n- Used for data analysis and visualization in various industries.\n\n**5. Scala:**\n- Scalable and efficient for big data processing and machine learning models.\n- Popular in data science for its parallelism and integration with Spark and Spark MLlib."
    }
    
    

问题排查

  • 如果您收到 Empty reply from server 消息,则容器可能尚未完成模型数据下载。再次检查 Pod 的日志中是否包含 Connected 消息,该消息表明模型已准备好进行应用。
  • 如果您看到 Connection refused,请验证您的端口转发已启用

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除已部署的资源

为避免系统因您在本指南中创建的资源而向您的 Google Cloud 账号收取费用,请运行以下命令并按照提示进行操作:

gcloud container clusters delete CLUSTER_NAME --region=REGION

gcloud iam service-accounts delete wi-jetstream@PROJECT_ID.iam.gserviceaccount.com

gcloud storage rm --recursive gs://BUCKET_NAME

后续步骤