使用 GKE 中的多主机 TPU 和 Saxml 来提供 LLM


本教程介绍如何通过 Saxml,在 Google Kubernetes Engine (GKE) 上使用多主机 TPU 切片节点池部署和应用大语言模型 (LLM),以实现高效的可伸缩架构。

背景

Saxml 是一个实验性系统,应用 PaxmlJAXPyTorch 框架。您可以采用这些框架,使用 TPU 来加速数据处理。为了演示 GKE 中 TPU 的部署,本教程应用了 175B LmCloudSpmd175B32Test 测试模型。GKE 分别在两个具有 4x8 拓扑的 v5e TPU 切片节点池上部署此测试模型。

为了正确部署测试模型,根据模型的大小定义了 TPU 拓扑。鉴于 Nx10 亿 16 位模型大约需要 2 倍 (2xN) GB 的内存,因此 175B LmCloudSpmd175B32Test 模型需要大约 350 GB 的内存。TPU v5e 单个 TPU 芯片具有 16 GB。为了支持 350 GB,GKE 需要 21 个 v5e TPU 芯片 (350/16= 21)。根据 TPU 配置的映射,本教程的正确 TPU 配置如下:

  • 机器类型:ct5lp-hightpu-4t
  • 拓扑:4x8(32 个 TPU 芯片)

在 GKE 中部署 TPU 时,请务必选择正确的 TPU 拓扑来应用模型。如需了解详情,请参阅规划 TPU 配置

目标

本教程适用于希望使用 GKE 编排功能来应用数据模型的 MLOps 或 DevOps 工程师或平台管理员。

本教程介绍以下步骤:

  1. 使用一个 GKE Standard 集群准备环境。该集群包含两个具有 4x8 拓扑的 v5e TPU 切片节点池。
  2. 部署 Saxml。Saxml 需要一个管理员服务器、一组用作模型服务器的 Pod、一个预建的 HTTP 服务器和一个负载均衡器。
  3. 使用 Saxml 应用 LLM。

下图展示了后续教程实现的架构:

GKE 上的多主机 TPU 的架构。
:GKE 上的多主机 TPU 的示例架构。

准备工作

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role colunn to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      前往 IAM
    2. 选择项目。
    3. 点击 授予访问权限
    4. 新的主账号字段中,输入您的用户标识符。 这通常是 Google 账号的电子邮件地址。

    5. 选择角色列表中,选择一个角色。
    6. 如需授予其他角色,请点击 添加其他角色,然后添加其他各个角色。
    7. 点击保存

准备环境

  1. 在 Google Cloud 控制台中,启动 Cloud Shell 实例:
    打开 Cloud Shell

  2. 设置默认环境变量:

      gcloud config set project PROJECT_ID
      export PROJECT_ID=$(gcloud config get project)
      export REGION=COMPUTE_REGION
      export ZONE=COMPUTE_ZONE
      export GSBUCKET=PROJECT_ID-gke-bucket
    

    替换以下值:

创建 GKE Standard 集群

使用 Cloud Shell 执行以下操作:

  1. 创建使用适用于 GKE 的工作负载身份联合的 Standard 集群:

    gcloud container clusters create saxml \
        --zone=${ZONE} \
        --workload-pool=${PROJECT_ID}.svc.id.goog \
        --cluster-version=VERSION \
        --num-nodes=4
    

    VERSION 替换为 GKE 版本号。GKE 在 1.27.2-gke.2100 及更高版本中支持 TPU v5e。如需了解详情,请参阅 GKE 中的 TPU 可用性

    集群创建可能需要几分钟的时间。

  2. 创建第一个节点池,名为 tpu1

    gcloud container node-pools create tpu1 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    
  3. 创建第二个节点池,名为 tpu2

    gcloud container node-pools create tpu2 \
        --zone=${ZONE} \
        --num-nodes=8 \
        --machine-type=ct5lp-hightpu-4t \
        --tpu-topology=4x8 \
        --cluster=saxml
    

您已创建以下资源:

  • 具有四个 CPU 节点的 Standard 集群。
  • 两个具有 4x8 拓扑的 v5e TPU 切片节点池。每个节点池代表 8 个 TPU 切片节点,这些节点各自具有 4 个 TPU 芯片。

必须在至少具有 4x8 拓扑切片(32 个 v5e TPU 芯片)的多主机 v5e TPU 切片上应用 175B 模型。

创建 Cloud Storage 存储桶

创建 Cloud Storage 存储桶以存储 Saxml 管理员服务器配置。正在运行的管理员服务器会定期保存其状态和已发布模型的详细信息。

在 Cloud Shell 中,运行以下命令:

gcloud storage buckets create gs://${GSBUCKET}

使用适用于 GKE 的工作负载身份联合配置工作负载访问权限

为应用分配 Kubernetes ServiceAccount,并将该 Kubernetes ServiceAccount 配置为充当 IAM 服务账号。

  1. 配置 kubectl 以与您的集群通信:

    gcloud container clusters get-credentials saxml --zone=${ZONE}
    
  2. 为您的应用创建 Kubernetes 服务账号:

    kubectl create serviceaccount sax-sa --namespace default
    
  3. 为您的应用创建 IAM 服务账号:

    gcloud iam service-accounts create sax-iam-sa
    
  4. 为您的 IAM 服务账号添加 IAM 政策绑定,以便对 Cloud Storage 执行读写操作:

    gcloud projects add-iam-policy-binding ${PROJECT_ID} \
      --member "serviceAccount:sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
      --role roles/storage.admin
    
  5. 通过在两个服务账号之间添加 IAM 政策绑定,允许 Kubernetes 服务账号模拟 IAM 服务账号。此绑定允许 Kubernetes 服务账号充当 IAM 服务账号,以便 Kubernetes 服务账号可以对 Cloud Storage 执行读写操作。

    gcloud iam service-accounts add-iam-policy-binding sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
      --role roles/iam.workloadIdentityUser \
      --member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/sax-sa]"
    
  6. 使用 IAM 服务账号的电子邮件地址为 Kubernetes 服务账号添加注解。这样,您的示例应用便知道要用于访问 Google Cloud 服务的服务账号。因此,在应用要使用任何标准 Google API 客户端库访问 Google Cloud 服务时,便会使用该 IAM 服务账号。

    kubectl annotate serviceaccount sax-sa \
      iam.gke.io/gcp-service-account=sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    

部署 Saxml

在本部分中,您将部署 Saxml 管理服务器和 Saxml 模型服务器。

部署 Saxml 管理服务器

  1. 创建以下 sax-admin-server.yaml 清单:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-admin-server
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-admin-server
      template:
        metadata:
          labels:
            app: sax-admin-server
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-admin-server
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server:v1.1.0
            securityContext:
              privileged: true
            ports:
            - containerPort: 10000
            env:
            - name: GSBUCKET
              value: BUCKET_NAME

    BUCKET_NAME 替换为您的 Cloud Storage 存储桶名称。

  2. 应用清单:

    kubectl apply -f sax-admin-server.yaml
    
  3. 验证管理员服务器 Pod 是否已启动并运行:

    kubectl get deployment
    

    输出类似于以下内容:

    NAME               READY   UP-TO-DATE   AVAILABLE   AGE
    sax-admin-server   1/1     1            1           52s
    

部署 Saxml 模型服务器

在多主机 TPU 切片中运行的工作负载要求每个 Pod 都有一个稳定的网络标识符,以发现同一 TPU 切片中的对等方。如需定义这些标识符,请使用 IndexedJobStatefulSet 及无头 Service 或 JobSet(它会自动为属于 JobSet 的所有作业创建无头 Service)。以下部分介绍如何使用 JobSet 管理多组模型服务器 Pod。

  1. 安装 JobSet v0.2.3 或更高版本。

    kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml
    

    JOBSET_VERSION 替换为 JobSet 版本。例如 v0.2.3

  2. 验证 JobSet 控制器是否在 jobset-system 命名空间中运行:

    kubectl get pod -n jobset-system
    

    输出类似于以下内容:

    NAME                                        READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-69449d86bc-hp5r6   2/2     Running   0          2m15s
    
  3. 在两个 TPU 切片节点池中部署两个模型服务器。保存以下 sax-model-server-set 清单:

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: sax-model-server-set
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: sax-model-server
          replicas: 2
          template:
            spec:
              parallelism: 8
              completions: 8
              backoffLimit: 0
              template:
                spec:
                  serviceAccountName: sax-sa
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
                    cloud.google.com/gke-tpu-topology: 4x8
                  containers:
                  - name: sax-model-server
                    image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server:v1.1.0
                    args: ["--port=10001","--sax_cell=/sax/test", "--platform_chip=tpuv5e"]
                    ports:
                    - containerPort: 10001
                    - containerPort: 8471
                    securityContext:
                      privileged: true
                    env:
                    - name: SAX_ROOT
                      value: "gs://BUCKET_NAME/sax-root"
                    - name: MEGASCALE_NUM_SLICES
                      value: ""
                    resources:
                      requests:
                        google.com/tpu: 4
                      limits:
                        google.com/tpu: 4

    BUCKET_NAME 替换为您的 Cloud Storage 存储桶名称。

    在此清单中:

    • replicas: 2 是作业副本的数量。每个作业代表一个模型服务器。因此,一组 8 个 Pod。
    • parallelism: 8completions: 8 等于每个节点池中的节点数量。
    • 如果有任何 Pod 失败,backoffLimit: 0 必须为零以将作业标记为失败。
    • ports.containerPort: 8471 是用于虚拟机通信的默认端口
    • name: MEGASCALE_NUM_SLICES 会取消设置环境变量,因为 GKE 未运行多切片训练。
  4. 应用清单:

    kubectl apply -f sax-model-server-set.yaml
    
  5. 验证 Saxml 管理服务器和模型服务器 Pod 的状态:

    kubectl get pods
    

    输出类似于以下内容:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-1-sl8w4   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-2-hb4rk   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-3-qv67g   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-4-pzqz6   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-5-nm7mz   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-6-7br2x   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-7-4pw6z   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-0-8mlf5   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-1-h6z6w   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-2-jggtv   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-3-9v8kj   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-4-6vlb2   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-5-h689p   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-6-bgv5k   1/1     Running   0          24m
    sax-model-server-set-sax-model-server-1-7-cd6gv   1/1     Running   0          24m
    

在此示例中,有 16 个模型服务器容器:sax-model-server-set-sax-model-server-0-0-nj4smsax-model-server-set-sax-model-server-1-0-8mlf5 是每个组中的两个主模型服务器。

您的 Saxml 集群有两个模型服务器,分别部署在两个具有 4x8 拓扑的 v5e TPU 切片节点池上。

部署 Saxml HTTP 服务器和负载均衡器

  1. 使用以下预构建映像 HTTP 服务器映像。保存以下 sax-http.yaml 清单:

    apiVersion: apps/v1
    kind: Deployment
    metadata:
      name: sax-http
    spec:
      replicas: 1
      selector:
        matchLabels:
          app: sax-http
      template:
        metadata:
          labels:
            app: sax-http
        spec:
          hostNetwork: false
          serviceAccountName: sax-sa
          containers:
          - name: sax-http
            image: us-docker.pkg.dev/cloud-tpu-images/inference/sax-http:v1.0.0
            ports:
            - containerPort: 8888
            env:
            - name: SAX_ROOT
              value: "gs://BUCKET_NAME/sax-root"
    ---
    apiVersion: v1
    kind: Service
    metadata:
      name: sax-http-lb
    spec:
      selector:
        app: sax-http
      ports:
      - protocol: TCP
        port: 8888
        targetPort: 8888
      type: LoadBalancer

    BUCKET_NAME 替换为您的 Cloud Storage 存储桶名称。

  2. 应用 sax-http.yaml 清单:

    kubectl apply -f sax-http.yaml
    
  3. 等待 HTTP 服务器容器完成创建:

    kubectl get pods
    

    输出类似于以下内容:

    NAME                                              READY   STATUS    RESTARTS   AGE
    sax-admin-server-557c85f488-lnd5d                 1/1     Running   0          35h
    sax-http-65d478d987-6q7zd                         1/1     Running   0          24m
    sax-model-server-set-sax-model-server-0-0-nj4sm   1/1     Running   0          24m
    ...
    
  4. 等待系统为 Service 分配外部 IP 地址:

    kubectl get svc
    

    输出类似于以下内容:

    NAME           TYPE           CLUSTER-IP    EXTERNAL-IP   PORT(S)          AGE
    sax-http-lb    LoadBalancer   10.48.11.80   10.182.0.87   8888:32674/TCP   7m36s
    

使用 Saxml

在 v5e TPU 多主机切片中的 Saxml 上加载、部署和应用模型:

加载模型

  1. 检索 Saxml 的负载均衡器 IP 地址。

    LB_IP=$(kubectl get svc sax-http-lb -o jsonpath='{.status.loadBalancer.ingress[*].ip}')
    PORT="8888"
    
  2. 在两个 v5e TPU 切片节点池中加载 LmCloudSpmd175B 测试模型:

    curl --request POST \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/publish --data \
    '{
        "model": "/sax/test/spmd",
        "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }'
    

    测试模型没有经过微调的检查点,权重是随机生成的。模型加载最多可能需要 10 分钟。

    输出类似于以下内容:

    {
        "model": "/sax/test/spmd",
        "path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
        "checkpoint": "None",
        "replicas": 2
    }
    
  3. 检查模型就绪情况:

    kubectl logs sax-model-server-set-sax-model-server-0-0-nj4sm
    

    输出类似于以下内容:

    ...
    loading completed.
    Successfully loaded model for key: /sax/test/spmd
    

    模型已完全加载。

  4. 获取模型的相关信息:

    curl --request GET \
    --header "Content-type: application/json" \
    -s ${LB_IP}:${PORT}/listcell --data \
    '{
        "model": "/sax/test/spmd"
    }'
    

    输出类似于以下内容:

    {
    "model": "/sax/test/spmd",
    "model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test",
    "checkpoint": "None",
    "max_replicas": 2,
    "active_replicas": 2
    }
    

应用模型

应用提示请求:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/generate --data \
'{
  "model": "/sax/test/spmd",
  "query": "How many days are in a week?"
}'

以下输出显示了模型响应的示例。此响应可能没有意义,因为测试模型具有随机权重。

取消发布模型

运行以下命令以取消发布模型:

curl --request POST \
--header "Content-type: application/json" \
-s ${LB_IP}:${PORT}/unpublish --data \
'{
    "model": "/sax/test/spmd"
}'

输出类似于以下内容:

{
  "model": "/sax/test/spmd"
}

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除已部署的资源

  1. 删除您为本教程创建的集群:

    gcloud container clusters delete saxml --zone ${ZONE}
    
  2. 删除服务账号:

    gcloud iam service-accounts delete sax-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
    
  3. 删除 Cloud Storage 存储桶:

    gcloud storage rm -r gs://${GSBUCKET}
    

后续步骤