使用 JAX、Ray Train 和 GKE 上的 TPU Trillium 训练 LLM

本教程介绍了如何使用 MaxTextRay Train 和 TPU 在 Google Kubernetes Engine (GKE) 上训练 Llama 3 8B 大语言模型 (LLM)。

本教程提供了一个完整的端到端演示,从配置必要的云基础架构到在多主机 TPU 上提交并成功运行训练工作负载。

本教程适用于希望了解如何在分布式多主机 TPU 切片上训练大型模型的平台管理员和运维人员,以及数据和 AI 专家。

背景

GKE、KubeRay、MaxText 和 TPU 的组合为大规模模型训练提供了一个强大且可扩缩的平台。本部分介绍本指南中使用的关键技术。

JAX

JAX 是一个面向加速器的数组计算和程序转换 Python 库,专为高性能数值计算和大规模机器学习而设计。

JAX 提供了一个可扩展的系统,用于转换 jax.gradjax.jitjax.vmap 等数值函数,利用 XLA 编译器创建高度优化的代码,可在 GPU 和 TPU 等加速器上高效扩展。JAX 的核心优势在于其可组合性,这使得用户能够组合这些转换来构建复杂的、高性能的数值程序,以进行分布式执行。

MaxText

MaxText 是一种高性能的开源大语言模型 (LLM),旨在实现可伸缩性和可自定义性。MaxText 基于 JAX 构建,并经过优化,可在 Cloud TPU 和 GPU 上高效运行。

TPU

张量处理单元 (TPU) 是 Google 专为优化机器学习工作负载而定制设计的加速器。与通用 CPU 或并行处理 GPU 不同,TPU 专门针对深度学习基础中的大规模矩阵和张量计算进行了高度优化,因此能够高效完成此特定任务。TPU 的主要优势在于大规模性能。

本教程使用 TPU Trillium,这是第六代 TPU。 如需了解详情,请参阅使用 TPU Trillium 的优势

KubeRay

KubeRay 是一个 Kubernetes 操作器,可提供一种在 Kubernetes 上部署、管理和监控 Ray 应用的统一方式。KubeRay 操作器通过 Ray on GKE 插件进行安装和管理,这是在 GKE 上部署和管理 Ray 集群的推荐方法。

目标

本教程介绍了如何执行以下操作:

  1. 设置具有多主机 TPU 节点池的 GKE 集群。
  2. 配置 KubeRay 以管理分布式训练环境。
  3. 构建包含 MaxText、Ray 和 JAX 依赖项的自定义 Docker 映像。
  4. 创建一个 Python 训练脚本,该脚本使用 Ray Train 的 JaxTrainer 在 TPU 切片中编排 MaxText 训练循环。
  5. 定义 RayCluster 自定义资源,以预配具有必要 TPU 资源的主节点和工作器节点。
  6. 将训练作业提交给 RayCluster 并监控其进度。
  7. 使用 Cloud Storage 存储模型检查点。

准备工作

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • Install the Google Cloud CLI.

  • 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  • 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Install the Google Cloud CLI.

  • 如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI

  • 如需初始化 gcloud CLI,请运行以下命令:

    gcloud init
  • Create or select a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
    • Create a Google Cloud project:

      gcloud projects create PROJECT_ID

      Replace PROJECT_ID with a name for the Google Cloud project you are creating.

    • Select the Google Cloud project that you created:

      gcloud config set project PROJECT_ID

      Replace PROJECT_ID with your Google Cloud project name.

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API:

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    gcloud services enable container.googleapis.com
  • Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin

    gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE

    Replace the following:

    • PROJECT_ID: Your project ID.
    • USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
    • ROLE: The IAM role that you grant to your user account.
  • 由于本教程使用 TPU Trillium (v6e),请选择可用的区域或可用区。如需了解详情,请参阅 Cloud TPU 配额

准备环境

在本教程中,您将使用 Cloud ShellgcloudCloud Shellhelm 预安装了本教程中使用的 kubectl、 和 命令行工具。

  1. 前往 Google Cloud 控制台

  2. 在 Google Cloud 控制台窗口顶部,点击激活 Cloud Shell 激活 Shell 按钮 按钮。

    一个 Cloud Shell 会话随即会在Google Cloud 控制台中的新框架内打开,并显示命令行提示符。

  3. 创建并激活 Python 虚拟环境:

    python3 -m venv ray-env
    source ray-env/bin/activate
    
  4. 安装 Ray CLI 和其他依赖项:

    pip install "ray[default]==2.49.1"
    
  5. 设置以下环境变量:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
    

    替换以下内容:

    • GS_BUCKET:Cloud Storage 存储桶的名称。
    • KSA_NAME:Kubernetes ServiceAccount 的名称。
    • CLUSTER_NAME:新集群的名称。
    • REGION:您的 TPU Trillium 容量可用的区域。
    • ZONE:您的 TPU Trillium 容量可用的可用区。如需了解详情,请参阅 GKE 中的 TPU 可用性
    • ARTIFACT_REGISTRY:Artifact Registry 代码库的名称。

创建 GKE 集群

您可以在 GKE Autopilot 或 Standard 集群中的 TPU 上配置 KubeRay。我们建议您使用 Autopilot 集群获得全托管式 Kubernetes 体验。如需选择最适合您的工作负载的 GKE 操作模式,请参阅GKE 操作模式简介

Autopilot

  1. 在 Cloud Shell 中,运行以下命令:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --machine-type=n1-standard-16 \
        --location=$REGION
    
  2. 如需与集群通信,请配置 kubectl

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=$ZONE
    

标准

  1. 在 Cloud Shell 中,运行以下命令以创建启用 Ray operator 插件的 Standard 集群:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator \
        --addons GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE
    

    此命令还会启用 GcsFuseCsiDriver,从而允许 Pod 将 Cloud Storage 存储分区作为本地文件系统进行装载。集群创建可能需要几分钟的时间。

  2. 如需与集群通信,请配置 kubectl

    gcloud container clusters get-credentials CLUSTER_NAME \
        --location=LOCATION
    
  3. 创建多主机 TPU 切片节点池:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4
    

GKE 会预配一个由四个 TPU Trillium (v6e) 虚拟机组成的节点池,这些虚拟机共同配置为具有 4x4 拓扑的多主机 TPU 切片,可用于分布式训练工作负载。

启用了 Ray 操作器的 GKE 集群会自动在集群中安装 KubeRay 和 KubeRay TPU webhook

配置 Cloud Storage 存储分区和服务账号

  1. 创建一个 Cloud Storage 存储分区,用于在多主机 TPU 节点之间共享检查点。

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
    
  2. 如需启用对 Cloud Storage 存储分区的访问权限,请创建 Kubernetes ServiceAccount:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
    
  3. 如需启用对 Cloud Storage 存储分区的访问权限,请向服务账号添加所需的 IAM 政策绑定:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
    

创建训练脚本

以下脚本使用 Ray Train 的 JaxTrainer 运行分布式 MaxText 训练作业。该脚本可为多主机 TPU 切片节点池配置训练环境,并在每个工作器节点上运行 MaxText 训练作业。train_loop_per_worker 函数封装了 MaxText 主要入口点,并使用 Ray 的分布式调度程序在多主机 TPU 切片上执行 MaxText 训练器。

  1. 将以下 Python 脚本保存为 maxtext_ray_trainer.py

    import os
    from absl import app
    import logging
    from typing import Sequence
    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer
    
    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main
    
        argv = config["argv"]
        maxtext_main(argv)
    
    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        app.run(main)
  2. 如需托管自定义映像,请创建 Artifact Registry 制品库:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
    
  3. 如需构建包含用于训练的 Ray 和 MaxText 依赖项的映像,请创建 Dockerfile

    # Start from a Ray base image which includes JaxTrainer API.
    # Maxtext with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312
    
    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*
    
    WORKDIR /app
    
    # Clone the Maxtext repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    
    RUN pip install --no-cache-dir uv
    
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps
    
    # Copy the Ray Maxtext trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    
    RUN chown -R ray:ray .
    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
  4. 构建 Docker 映像、为其添加标记并将其推送到 Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}
    

训练模型

  1. 将以下示例清单保存为 maxtext-tpu-cluster.yaml

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                - containerPort: 6379
                  name: gcs-server
                - containerPort: 8265
                  name: dashboard
                - containerPort: 10001
                  name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                - name: gcs-fuse-csi-ephemeral
                  mountPath: /data
                - name: dshm
                  mountPath: /dev/shm
            volumes:
            - name: gcs-fuse-cache
              emptyDir:
                medium: Memory
            - name: dshm
              emptyDir:
                medium: Memory
            - name: gcs-fuse-csi-ephemeral
              csi:
                driver: gcsfuse.csi.storage.gke.io
                volumeAttributes:
                  bucketName: ${GS_BUCKET}
                  mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
              volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    上述 RayCluster 规范会创建一个 TPU 工作器组,每个副本包含四个工作器 (numOfHosts: 4)。每个工作器请求 4 个 TPU 芯片 (google.com/tpu: "4")。工作器将调度到运行 TPU Trillium (tpu-v6e-slice) 的节点上,该节点是同一并置多主机切片的一部分。KubeRay 会以原子方式扩缩所有四个 worker,并且 GKE 会通过变更网络钩子来引导所需的 JAX 环境变量以及用于调度的 Pod 亲和性。

  2. 如需在 YAML 文件中配置所需的值,请使用 envsubst 创建 RayCluster:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
    
  3. 验证集群是否已准备就绪并正在运行:

    kubectl get rayclusters maxtext-tpu-cluster
    

    输出应类似如下所示:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
    
  4. 如需通过 Ray 头服务访问 Ray 信息中心,请建立端口转发会话:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
    
  5. 验证 RayCluster 是否可从本地环境访问:

    ray list nodes --address http://localhost:8265
    

    输出应类似如下所示:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5
    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
    0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
    
  6. 将 JaxTrainer 脚本提交到 RayCluster,并检查 RayJob 是否成功完成:

    ray job submit \
      --address http://localhost:8265 \
      -- python /app/maxtext_ray_trainer.py \
          /app/maxtext/src/MaxText/configs/base.yml \
           base_output_directory=/data/ \
          dataset_type=synthetic \
          per_device_batch_size=1 \
          max_target_length=4096 \
          model_name=llama3-8b \
          steps=100 \
          ici_fsdp_parallelism=4 \
          ici_tensor_parallelism=4 \
          run_name=rayjob-8b-4096-tp4-4x4
    

    上述命令会提交 Python 脚本,该脚本会调用 JaxTrainer Ray 代码到 RayCluster。ray job submit 命令包含一些特定于 MaxText 的实参,用于传递给模型配置。

    在终端中,您应该会看到类似如下所示的输出:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
    

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留该项目但删除各个资源。

  1. 删除 RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
    
  2. 删除 GKE 集群:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
    
  3. 删除 Cloud Storage 存储桶:

    gsutil rm -r gs://${GS_BUCKET}
    
  4. 删除 Artifact Registry 代码库:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
    

后续步骤