使用 GKE 中的多主机 TPU 和 JetStream 和 Pathways 来提供 LLM


本指南介绍了如何跨多个节点使用张量处理单元 (TPU),在 Google Kubernetes Engine (GKE) 上提供先进的大语言模型 (LLM),例如 Llama 3.1 405B

本指南演示了如何使用可移植的开源技术(Kubernetes、JetStreamPathways on CloudLeaderWorkerSet (LWS) API)在 GKE 上部署和提供 AI/机器学习工作负载,并利用 GKE 的精细控制、可伸缩性、弹性、可移植性和成本效益。

背景

大语言模型的规模不断扩大,已无法在单个主机 TPU 切片上运行。对于机器学习推理,您可以使用 Pathways on Cloud 在 GKE 上跨多个互连的 TPU 节点运行大规模多主机推理。在本指南中,您将逐步了解如何预配具有多主机 TPU 切片的 GKE 集群,使用 Pathways on Cloud 二进制文件,通过 MaxText 框架启动 JetStream 服务器,以及发出多主机推理请求。

通过 JetStreamMaxTextPathways 使用 GKE 上的 TPU 应用 LLM,您可以构建一个可用于生产用途的强大服务解决方案,具备托管式 Kubernetes 的所有优势,包括经济高效、可伸缩性和更高的可用性。本部分介绍本教程中使用的关键技术。

TPU 简介

TPU 是 Google 定制开发的应用专用集成电路 (ASIC),用于加速使用 TensorFlowPyTorchJAX 等框架构建的机器学习和 AI 模型。

使用 GKE 中的 TPU 之前,我们建议您完成以下学习路线:

  1. 了解 Cloud TPU 系统架构中的当前 TPU 版本可用性。
  2. 了解 GKE 中的 TPU

本教程介绍如何应用 Llama 3.1-405B 模型。GKE 在多主机 TPU v6e 节点上部署模型,并根据模型要求配置 TPU 拓扑,以低延迟响应提示。

Cloud 上的学习路线

Pathways 是一个适用于加速器的大规模编排层。Pathways 经过精心设计,可用于探索新的系统和机器学习研究理念,同时保持当前模型的出色性能。Pathways 使单个 JAX 客户端进程能够协调一个或多个大型 TPU 切片之间的计算,从而简化跨数百或数千个 TPU 芯片的机器学习计算。

JetStream

JetStream 是由 Google 开发的开源推理服务框架。JetStream 可以在 TPU 和 GPU 上实现高性能、高吞吐量和内存优化的推理。JetStream 提供高级性能优化(包括连续批处理、KV 缓存优化和量化技术),以协助 LLM 部署。JetStream 支持 PyTorch/XLA 和 JAX TPU 服务,从而优化性能。

MaxText

MaxText是一个高性能、可扩缩且适应性强的 JAX LLM 实现,基于如下开源 JAX 仓库构建:FlaxOrbaxOptax。MaxText 的仅解码器 LLM 实现是使用 Python 编写的。它大量利用 XLA 编译器来实现高性能,而无需构建自定义内核。

如需详细了解 MaxText 支持的最新模型和参数大小,请参阅 MaxText 项目仓库

Llama 3.1 405B

Llama 3.1 405B 是由 Meta 提供的大语言模型,专为各种自然语言处理任务(包括文本生成、翻译和问答)而设计。GKE 提供所需的基础设施,以支持这种规模的模型的分布式训练和服务需求。

如需了解详情,请参阅 Llama 文档

架构

本部分介绍本教程中使用的 GKE 架构。该架构包括一个 GKE Standard 集群,该集群用于预配 TPU 并托管 JetStream 和 Pathways 组件以部署和应用模型。

下图展示了此架构的组件:

具有多主机 TPU 节点池(其中包含 JetStream 和 Pathways 组件)的 GKE 集群的架构。

此架构包括以下组件:

  • GKE Standard 区域级集群。
  • 一个多主机 TPU 切片节点池,用于托管 JetStream 部署和 Pathways 组件。
  • Pathways resource manager 管理加速器资源,并协调用户作业的加速器分配。
  • Pathways clientPathways resource manager 协同工作,以确定编译后的程序放置在何处以供执行。
  • Pathways worker 在加速器机器上运行并执行计算,然后通过 IFRT 代理服务器将数据发送回工作负载。
  • IFRT proxy client 实现了 OSS 临时框架运行时 (IFRT) API,并充当工作负载与 Pathways 组件之间的通信桥梁。
  • IFRT proxy serverIFRT proxy client 接收请求并将其转发给 Pathways client,从而分配工作。
  • JetStream-Pathways 容器提供了一个基于 JAX 的推理服务器,该服务器接收推理请求并将其执行过程委托给 Pathways workers
  • Service 组件将入站流量分布到所有 JetStream HTTP 副本。
  • JetStream HTTP 是一个 HTTP 服务器,它接受封装容器形式的 JetStream 所需格式的请求并将其发送到 JetStream 的 GRPC 客户端

准备工作

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  • Make sure that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Enable the API

  • Make sure that you have the following role or roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin, roles/resourcemanager.projectIamAdmin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      前往 IAM
    2. 选择项目。
    3. 点击 授予访问权限
    4. 新的主账号字段中,输入您的用户标识符。 这通常是 Google 账号的电子邮件地址。

    5. 选择角色列表中,选择一个角色。
    6. 如需授予其他角色,请点击 添加其他角色,然后添加其他各个角色。
    7. 点击 Save(保存)。
  • 确保您有足够的配额用于 16 个 TPU v6e PodSlice Lite 芯片。在本教程中,您将使用按需实例
  • 确保您的 Google Cloud 项目已列入 Pathways 的许可名单。

获取对模型的访问权限

如需获取 Meta Llama 3.1-405B 检查点以部署到 GKE,请按以下步骤操作:

  1. 签署许可同意协议。
  2. 访问 Meta Llama 下载页面
  3. 查看并接受模型条款及条件,并获取下载模型所需的网址。
  4. 如需下载模型检查点,请找到相应模型的模型 ID。如需查看支持的模型及其 ID 的列表,请参阅 llama CLI 文档。例如,对于 Llama 3.1-405B 模型,请使用 Llama 3.1-405B-Instruct:bf16-mp16

准备环境

在本教程中,您将使用 Cloud Shell 来管理Google Cloud上托管的资源。Cloud Shell 中预安装了本教程所需的软件,包括 kubectl gcloud CLI

如需使用 Cloud Shell 设置您的环境,请按照以下步骤操作:

  1. 在 Google Cloud 控制台中,点击 Google Cloud 控制台中的 Cloud Shell 激活图标 激活 Cloud Shell 以启动 Cloud Shell 会话。此操作会在 Google Cloud 控制台的底部窗格中启动会话。

  2. 设置默认环境变量:

    gcloud config set project PROJECT_ID
    gcloud config set billing/quota_project PROJECT_ID
    export PROJECT_ID=$(gcloud config get project)
    export CLUSTER_NAME=CLUSTER_NAME
    export BUCKET_NAME=BUCKET_NAME
    export CONTROL_PLANE_LOCATION=CONTROL_PLANE_LOCATION
    export NODE_LOCATION=NODE_LOCATION
    export CLUSTER_VERSION=CLUSTER_VERSION
    export MACHINE_TYPE=ct6e-standard-4t
    export TPU_TYPE=v6e
    export TOPOLOGY=4x4
    export WORKERS_PER_SLICE=4
    

    替换以下值:

    • PROJECT_ID:您的 Google Cloud 项目 ID
    • CLUSTER_NAME:GKE 集群的名称。
    • BUCKET_NAME:Cloud Storage 存储桶的名称。您无需指定 gs:// 前缀。
    • CONTROL_PLANE_LOCATION:GKE 集群、Cloud Storage 存储分区和 TPU 节点所在的 Compute Engine 区域。该区域包含可以使用 TPU v6e 机器类型的可用区(例如 us-east1us-east5europe-west4asia-northeast1us-south1)。
    • NODE_LOCATION可用的 TPU 资源所在的可用区(例如 us-east1-d)。
    • CLUSTER_VERSION:GKE 版本,必须支持您要使用的机器类型。 请注意,默认 GKE 版本可能无法为您的目标 TPU 提供可用性。 如需查看 TPU 机器类型可用的最低 GKE 版本列表,请参阅 GKE 中的 TPU 可用性
    • MACHINE_TYPE:v6e 机器类型。
    • TPU_TYPE:用于命名节点池 (v6e) 的前缀。
    • TOPOLOGY:TPU v6e 拓扑。
    • WORKERS_PER_SLICE:每个节点池或 TPU 切片的节点数。

创建和配置 Google Cloud 资源

如需创建所需的资源,请按照以下说明操作:

创建 GKE 集群

  1. 创建区域级 GKE Standard 集群:

    gcloud container clusters create CLUSTER_NAME \
        --project=PROJECT_ID \
        --cluster-version=CLUSTER_VERSION \
        --location=CONTROL_PLANE_LOCATION \
        --scopes=cloud-platform \
        --machine-type=n2-standard-32
    

    集群创建可能需要几分钟的时间。

    CLUSTER_VERSION 替换为适当的集群版本

  2. 创建一个 TPU v6e 节点池,该节点池具有 4x4 拓扑,且每个节点有 4 个 TPU 芯片:

    gcloud container node-pools create multihost-np \
    --project=PROJECT_ID \
    --location=CONTROL_PLANE_LOCATION \
    --node-locations=NODE_LOCATION \
    --cluster=CLUSTER_NAME \
    --machine-type=MACHINE_TYPE \
    --num-nodes=WORKERS_PER_SLICE \
    --tpu-topology=TOPOLOGY \
    --scopes cloud-platform \
    --placement-type=COMPACT \
    --workload-metadata=GCE_METADATA
    

为存储对象访问配置服务账号

配置 Kubernetes 服务账号以充当 IAM 服务账号。

  1. 为您的应用创建 IAM 服务账号:

    gcloud iam service-accounts create jetstream-pathways
    
  2. 为您的 IAM 服务账号添加 IAM 政策绑定,以便管理 Cloud Storage。这是为了使您的 IAM 服务账号能够访问将存储检查点的存储桶:

    gcloud projects add-iam-policy-binding ${PROJECT} \
      --member "serviceAccount:jetstream-pathways@${PROJECT}.iam.gserviceaccount.com" \
      --role roles/storage.objectUser
    
    gcloud projects add-iam-policy-binding ${PROJECT} \
      --member "serviceAccount:jetstream-pathways@${PROJECT}.iam.gserviceaccount.com" \
      --role roles/storage.insightsCollectorService
    
  3. 使用 IAM 服务账号的电子邮件地址为 Kubernetes 服务账号添加注解。

    kubectl annotate serviceaccount default \
    iam.gke.io/gcp-service-account=jetstream-pathways@${PROJECT}.iam.gserviceaccount.com
    

配置 Docker 以向 Artifact Registry 进行身份验证

配置 Docker 以向 Artifact Registry 进行身份验证,以便拉取已列入许可名单的 Pathways 映像:

gcloud auth login
gcloud auth configure-docker

检查点转换

如需将 Meta Llama 3.1-405B 检查点转换为与 MaxText 兼容的 int8 推理检查点,请完成使用 Llama3.1-405B 进行检查点转换中的步骤。您的部署使用带有 load_parameters_path 标志的检查点。

创建 Cloud Storage 存储桶以存储 Pathways 临时文件

创建一个 Cloud Storage 存储桶来存储您的 Pathways 临时文件,例如编译缓存:

export PATHWAYS_BUCKET=PATHWAYS_BUCKET
gcloud storage buckets create gs://$PATHWAYS_BUCKET

部署 JetStream-MaxText 和 Pathways

部署 JetStream-MaxText 和 Pathways 模型服务器。

连接到 GKE 集群

gcloud container clusters get-credentials "${CLUSTER}" --project "${PROJECT}" --location "${ZONE}"

部署 LeaderWorkerSet (LWS) API

LWS 是一种自定义资源,旨在部署和管理有状态的分布式应用,尤其是那些采用领导者-工作器架构的应用。它尤其适合 AI/ML 工作负载,在这种工作负载中,大型模型会被分片并跨多个节点上的多个设备提供服务。

VERSION=v0.6.1
kubectl apply --server-side -f https://github.com/kubernetes-sigs/lws/releases/download/$VERSION/manifests.yaml

等待 LeaderWorkerSet 控制器完全可用:

kubectl wait deploy/lws-controller-manager -n lws-system --for=condition=available --timeout=5m

输出应类似如下所示:

deployment.apps/lws-controller-manager condition met

验证 LeaderWorkerSet 控制器是否在 lws-system 命名空间中运行:

kubectl get pod -n lws-system

输出应类似如下所示:

NAME                          READY   STATUS    RESTARTS    AGE
lws-controller-manager-abcd   1/1     Running   0           40s
lws-controller-manager-efgh   1/1     Running   0           40s

部署工作负载清单

  1. 将以下清单保存为 jetstream-pathways-llama-3-1-405b-4x4.yaml

    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 4x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.5.3
              args:
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
              args:
              - --server_port=38677
              - --gcs_scratch_location=PATHWAYS_BUCKET
              - --node_type=resource_manager
              - --instance_count=1
              - --instance_type=tpuv6e:4x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              env:
              - name: LOG_LEVEL
                value: "INFO"
              args:
              - MaxText/configs/v5e/inference/llama3_405b_v5e-64.yml
              - model_name=llama3.1-405b
              - load_parameters_path=CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - async_checkpointing=false
              - steps=1
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=2
              - ici_tensor_parallelism=8
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=6
              - enable_single_controller=true
              - quantization=int8
              - quantize_kvcache=true
              - checkpoint_is_quantized=true
              - enable_model_warmup=true
              imagePullPolicy: Always
              ports:
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 600
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 10
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 10
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 4x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=PATHWAYS_BUCKET
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
  2. load_parameters_path 字段的值设置为在检查点转换过程中生成的检查点路径。

    • 对于 bf16 检查点,路径应类似于 gs://OUTPUT_BUCKET_DIRECTORY/bf16/unscanned/checkpoints/0/items
    • 对于 int8 检查点,它应类似于 gs://OUTPUT_BUCKET_DIRECTORY/int8

    gcs_scratch_location 字段的值设置为您之前创建的 Pathways 存储桶。

    perl -pi -e 's|CHECKPOINT_PATH|gs://OUTPUT_BUCKET_DIRECTORY/int8|g' jetstream-pathways-llama-3-1-405b-4x4.yaml
    perl -pi -e 's|PATHWAYS_BUCKET|gs://PATHWAYS_BUCKET|g' jetstream-pathways-llama-3-1-405b-4x4.yaml
    

应用 Deployment 清单

应用清单以部署服务器:

kubectl apply -f jetstream-pathways-llama-3-1-405b-4x4.yaml

模型服务器应会启动。

验证模型服务器启动

405B 模型可能需要大约 10 到 20 分钟才能恢复检查点。如果您启用了 enable_model_warmup 标志,则在模型预热期间可能还需要等待额外的时间。

kubectl logs -f jetstream-pathways-0 -c jax-tpu

输出类似于以下内容:

2025-03-02 02:15:07,682 - JetstreamLogger - INFO - Initializing the driver with 1 prefill engines and 1 generate engines in interleaved mode
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up prefill thread 0.
2025-03-02 02:15:07,683 - JetstreamLogger - INFO - Spinning up transfer thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up generate thread 0.
2025-03-02 02:15:07,684 - JetstreamLogger - INFO - Spinning up detokenize thread 0.
2025-03-02 02:15:07,685 - JetstreamLogger - INFO - Driver initialized.
...
...
...
INFO:     Started server process [7]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:9999 (Press CTRL+C to quit)

提供 Llama 3.1-405b

如需提供 Llama 3.1-405b 模型,请设置端口转发:

kubectl port-forward svc/jetstream-svc 8000:8000

借助端口转发,您可以从集群外部访问服务。您可以通过 GKE 的 ClusterIP Service 访问 JetStream-Pathways Deployment。只能从集群内部访问 ClusterIP Service。

与模型交互

在新终端中,运行以下命令:

curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8000/generate \
--data \
'{
    "prompt": "What are the top 5 programming languages",
    "max_tokens": 200
}'

由于模型预热,初始请求可能需要几秒钟才能完成。输出应类似如下所示:

{
    "response": " for web development?\nThe top 5 programming languages for web development are:\n1. **JavaScript**: JavaScript is the most popular language for web development, used by over 90% of websites for client-side scripting. It's also popular for server-side programming with technologies like Node.js.\n2. **HTML/CSS**: HTML (Hypertext Markup Language) and CSS (Cascading Style Sheets) are not programming languages, but are essential for building websites. HTML is used for structuring content, while CSS is used for styling and layout.\n3. **Python**: Python is a popular language for web development, especially with frameworks like Django and Flask. It's known for its simplicity, flexibility, and large community of developers.\n4. **Java**: Java is a popular language for building enterprise-level web applications, especially with frameworks like Spring and Hibernate. It's known for its platform independence, strong security features, and large community of developers.\n5. **PHP**: PHP is a mature language for web"
}

您已成功完成以下操作:

  1. 使用 TPU 在 GKE 上部署了 JetStream 模型服务器,其中包含 MaxText 和 Pathways。
  2. gs://BUCKET_NAME 中创建了 Llama 3.1-405B int8 检查点。
  3. 应用了模型并与之互动。

分离式投放

分离式推理是一种在推理 LLM 时将预填充和解码阶段拆分到不同主机上的技术。此方法可优化资源利用率,从而提高吞吐量并缩短延迟时间。

  • 预填充:对输入提示进行前向传递,以初始化键值对缓存。

  • 解码:一种逐步生成输出令牌的程序,每步生成一个令牌,每次迭代生成一个 KV 缓存值。

  1. 设置默认环境变量:

    export NODE_POOL_NAME=dis-v6e-8
    export NODE_POOL_SIZE=2
    export MACHINE_TYPE=ct6e-standard-4t
    export TOPOLOGY=2x4
    export WORKERS_PER_SLICE=2
    
  2. 创建两个使用 v6e-8 节点的节点池:

    for i in $(seq 1 NODE_POOL_SIZE); do
      gcloud container node-pools create NODE_POOL_NAME-${i}-np \
      --project=PROJECT \
      --cluster=CLUSTER_NAME \
      --location=CONTROL_PLANE_LOCATION \
      --node-locations=NODE_LOCATION \
      --machine-type=MACHINE_TYPE \
      --num-nodes=WORKERS_PER_SLICE \
      --tpu-topology=TOPOLOGY \
      --scopes=cloud-platform \
      --workload-metadata=GCE_METADATA
    done
    

检查点转换

如需将 Meta Llama 2-70B 检查点转换为与 MaxText 兼容的 int8 推理检查点,请完成使用 Llama2-70B 进行检查点转换中的步骤。在确认 Meta 条款及条件时,选择 Llama2-70B 作为模型。您的部署使用带有 load_parameters_path 标志的检查点。

checkpoint-job.yaml 文件中替换以下参数:

- --meta_url=META_URL
- --model_name=llama-2
- --model_path=Llama-2-70b-chat
- --output_directory=gs://BUCKET_NAME/maxtext/llama-2-70b

该检查点将通过 load_parameters_path 标志用于您的部署中。

部署采用分离式服务的 JetStream Pathways

  1. 将以下清单保存为 jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml

    apiVersion: leaderworkerset.x-k8s.io/v1
    kind: LeaderWorkerSet
    metadata:
      name: jetstream-pathways
      annotations:
        leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      replicas: 1
      leaderWorkerTemplate:
        subGroupPolicy:
          subGroupSize: 2
        leaderTemplate:
          metadata:
            labels:
              app: jetstream-pathways
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 2x4
            tolerations:
            - key: "google.com/tpu"
              operator: "Exists"
              effect: "NoSchedule"
            containers:
            - name: pathways-proxy
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:jax-0.5.3
              args:
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --server_port=38681
              - --gcs_scratch_location=gs://cloud-pathways-staging/tmp
              - --xla_jf_auto_cross_replica_sharding=false
              - --xla_tpu_enable_windowed_einsum_for_reduce_scatter=false
              - --xla_tpu_enable_windowed_einsum_for_all_gather=false
              - --xla_tpu_prefer_latch_optimized_rhs_layouts=true
              - --xla_tpu_enable_experimental_fusion_cost_model=false
              - --xla_tpu_dot_dot_fusion_duplicated=false
              - --xla_tpu_dot_dot_fusion=true
              - --xla_jf_conv_input_fusion=true
              - --xla_jf_conv_output_fusion=true
              - --xla_tpu_rwb_fusion=false
              - --xla_tpu_copy_fusion_pad_unpad_ratio=0
              - --xla_tpu_licm_size_inflation_ratio=1
              - --xla_tpu_copy_elision_analysis_allowance=150000
              - --xla_tpu_copy_insertion_use_region_analysis_limit=10000
              - --xla_tpu_order_dot_after_layout=true
              - --xla_jf_rematerialization_percent_shared_memory_limit=100
              - --xla_tpu_use_repeated_instance_for_preferred_prefetch_time=true
              - --xla_tpu_enforce_prefetch_fifo_order=false
              - --xla_tpu_prefetch_interval_picker_size_override=6000000
              - --xla_tpu_async_copy_bandwidth_scaling_factor=1
              - --xla_tpu_nd_short_transfer_max_chunks=-1
              - --xla_tpu_enable_aggressive_broadcast_priority_update=true
              - --xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers=SQRT
              - --xla_tpu_memory_bound_loop_optimizer_options=enabled:true
              - --xla_tpu_enable_copy_fusion=true
              - --xla_tpu_enable_cross_program_prefetch_freeing=false
              - --xla_tpu_enable_dot_strength_reduction=true
              - --xla_tpu_layout_use_dot_grouping=false
              - --xla_tpu_msa_inefficient_use_to_copy_ratio=0.5
              - --xla_tpu_reduce_loop_fusion_dup_with_unfusable_user=false
              - --xla_tpu_vector_load_fusion_window=1024
              - --xla_tpu_vector_store_fusion_window=256
              - --xla_jf_conv_reshape_fusion=false
              - --xla_tpu_input_conv_multi_users=false
              - --xla_tpu_enable_multi_level_input_dot_dot_fusion=false
              - --xla_tpu_enable_multi_level_output_dot_dot_fusion=false
              - --xla_tpu_dot_dot_fusion_separable_convs_only=false
              - --xla_tpu_enable_multi_level_nested_loop_fusion=true
              - --xla_tpu_nested_dot_fusion=true
              - --xla_tpu_enable_multi_level_nested_dot_fusion=false
              - --xla_jf_enable_multi_output_fusion=true
              - --xla_tpu_use_lp_llo_scheduler_for_dot_dot_fusions=false
              - --xla_tpu_enable_flash_attention=true
              imagePullPolicy: Always
              ports:
              - containerPort: 38681
            - name: pathways-rm
              env:       
              - name: HOST_ADDRESS
                value: "$(LWS_LEADER_ADDRESS)"
              - name: TPU_SKIP_MDS_QUERY
                value: "true"
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
              args:
              - --server_port=38677
              - --gcs_scratch_location=PATHWAYS_BUCKET
              - --node_type=resource_manager
              - --instance_count=2
              - --instance_type=tpuv6e:2x4
              imagePullPolicy: Always
              ports:
              - containerPort: 38677
            - name: jax-tpu
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pathways:v0.2.0
              args:
              - MaxText/configs/base.yml
              - tokenizer_path=assets/tokenizer.llama2
              - load_parameters_path=CHECKPOINT_PATH
              - max_prefill_predict_length=1024
              - max_target_length=2048
              - model_name=llama2-70b
              - ici_fsdp_parallelism=1
              - ici_autoregressive_parallelism=1
              - ici_tensor_parallelism=-1
              - scan_layers=false
              - weight_dtype=bfloat16
              - per_device_batch_size=27
              - checkpoint_is_quantized=true 
              - quantization=int8
              - quantize_kvcache=true
              - compute_axis_order=0,2,1,3
              - ar_cache_axis_order=0,2,1,3
              - stack_prefill_result_cache=True
              - inference_server=ExperimentalMaxtextDisaggregatedServer_8
              - inference_benchmark_test=True
              - enable_model_warmup=True
              env:
              - name: LOG_LEVEL
                value: "INFO"
              imagePullPolicy: Always
              securityContext:
                capabilities:
                  add: ["SYS_PTRACE", "NET_ADMIN", "SYS_TIME"]
              ports: 
              - containerPort: 9000
              startupProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 1
                initialDelaySeconds: 240
                failureThreshold: 10000
              livenessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
              readinessProbe:
                httpGet:
                  path: /healthcheck
                  port: 8000
                  scheme: HTTP
                periodSeconds: 60
                failureThreshold: 100
            - name: jetstream-http
              image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.3
              imagePullPolicy: Always
              ports:
              - containerPort: 8000
        size: 5
        workerTemplate:
          spec:
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
              cloud.google.com/gke-tpu-topology: 2x4
            containers:
            - name: worker
              args:
              - --server_port=38679
              - --resource_manager_address=$(LWS_LEADER_ADDRESS):38677
              - --gcs_scratch_location=PATHWAYS_BUCKET
              image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:jax-0.5.3
              imagePullPolicy: Always
              ports:
              - containerPort: 38679
              resources:
                limits:
                  google.com/tpu: "4"
    --- 
    apiVersion: v1
    kind: Service
    metadata:
      name: jetstream-svc
    spec:
      selector:
        app: jetstream-pathways
      ports:
      - protocol: TCP
        name: jetstream-http
        port: 8000
        targetPort: 8000
  2. load_parameters_path 字段的值设置为在检查点转换过程中生成的检查点路径。

    • 对于 bf16 检查点,路径应类似于 gs://OUTPUT_BUCKET_DIRECTORY/bf16/unscanned/checkpoints/0/items
    • 对于 int8 检查点,它应类似于 gs://OUTPUT_BUCKET_DIRECTORY/int8

    gcs_scratch_location 字段的值设置为您之前创建的 Pathways 存储桶。

    perl -pi -e 's|CHECKPOINT_PATH|BUCKET_NAME/maxtext/llama-2-70b/int8|g' jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml
    perl -pi -e 's|PATHWAYS_BUCKET|gs://PATHWAYS_BUCKET|g' jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml
    
  3. 应用清单:

    kubectl apply -f jetstream-pathways-disagg-llama-2-70b-2-2x4.yaml
    

    模型服务器可能需要一些时间才能恢复检查点,具体取决于检查点的大小。一个 70B 模型可能需要大约 8 分钟来恢复检查点,包括模型预热。您可以进一步观察日志,通过验证模型服务器启动来确定就绪点,并通过设置端口转发来部署模型,以便与模型互动

您已成功完成以下操作:

  1. 使用 TPU 和分离式服务在 GKE 上部署了 JetStream 模型服务器,其中包含 MaxText 和 Pathways。
  2. gs://BUCKET_NAME 中创建了 Llama 2-70B int8 检查点。
  3. 应用了模型并与之互动。

问题排查

  • 如果您收到 Empty reply from server 消息,则容器可能尚未完成模型数据下载。再次检查 Pod 的日志中是否包含 Connected 消息,该消息表明模型已准备好进行应用。
  • 如果您看到 Connection refused 消息,请验证您的端口转发处于活跃状态

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

删除已部署的资源

为避免因您在本指南中创建的资源导致您的 Google Cloud 账号产生费用,请运行以下命令并按照提示操作:

gcloud container clusters delete CLUSTER_NAME --location=CONTROL_PLANE_LOCATION

gcloud iam service-accounts delete jetstream-pathways@PROJECT_ID.iam.gserviceaccount.com

gcloud storage rm --recursive gs://BUCKET_NAME

后续步骤