本教程介绍了如何使用 MaxText、Ray Train 和 TPU 在 Google Kubernetes Engine (GKE) 上训练 Llama 3 8B 大语言模型 (LLM)。
本教程提供了一个完整的端到端演示,从配置必要的云基础架构到在多主机 TPU 上提交并成功运行训练工作负载。
本教程适用于希望了解如何在分布式多主机 TPU 切片上训练大型模型的平台管理员和运维人员,以及数据和 AI 专家。
背景
GKE、KubeRay、MaxText 和 TPU 的组合为大规模模型训练提供了一个强大且可扩缩的平台。本部分介绍本指南中使用的关键技术。
JAX
JAX 是一个面向加速器的数组计算和程序转换 Python 库,专为高性能数值计算和大规模机器学习而设计。
JAX 提供了一个可扩展的系统,用于转换 jax.grad、jax.jit 和 jax.vmap 等数值函数,利用 XLA 编译器创建高度优化的代码,可在 GPU 和 TPU 等加速器上高效扩展。JAX 的核心优势在于其可组合性,这使得用户能够组合这些转换来构建复杂的、高性能的数值程序,以进行分布式执行。
MaxText
MaxText 是一种高性能的开源大语言模型 (LLM),旨在实现可伸缩性和可自定义性。MaxText 基于 JAX 构建,并经过优化,可在 Cloud TPU 和 GPU 上高效运行。
TPU
张量处理单元 (TPU) 是 Google 专为优化机器学习工作负载而定制设计的加速器。与通用 CPU 或并行处理 GPU 不同,TPU 专门针对深度学习基础中的大规模矩阵和张量计算进行了高度优化,因此能够高效完成此特定任务。TPU 的主要优势在于大规模性能。
本教程使用 TPU Trillium,这是第六代 TPU。 如需了解详情,请参阅使用 TPU Trillium 的优势。
KubeRay
KubeRay 是一个 Kubernetes 操作器,可提供一种在 Kubernetes 上部署、管理和监控 Ray 应用的统一方式。KubeRay 操作器通过 Ray on GKE 插件进行安装和管理,这是在 GKE 上部署和管理 Ray 集群的推荐方法。
目标
本教程介绍了如何执行以下操作:
- 设置具有多主机 TPU 节点池的 GKE 集群。
- 配置 KubeRay 以管理分布式训练环境。
- 构建包含 MaxText、Ray 和 JAX 依赖项的自定义 Docker 映像。
- 创建一个 Python 训练脚本,该脚本使用 Ray Train 的
JaxTrainer在 TPU 切片中编排 MaxText 训练循环。 - 定义
RayCluster自定义资源,以预配具有必要 TPU 资源的主节点和工作器节点。 - 将训练作业提交给
RayCluster并监控其进度。 - 使用 Cloud Storage 存储模型检查点。
准备工作
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
Install the Google Cloud CLI.
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
Create or select a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Create a Google Cloud project:
gcloud projects create PROJECT_ID
Replace
PROJECT_IDwith a name for the Google Cloud project you are creating. -
Select the Google Cloud project that you created:
gcloud config set project PROJECT_ID
Replace
PROJECT_IDwith your Google Cloud project name.
-
Verify that billing is enabled for your Google Cloud project.
-
Enable the required API:
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.gcloud services enable container.googleapis.com
-
Install the Google Cloud CLI.
-
如果您使用的是外部身份提供方 (IdP),则必须先使用联合身份登录 gcloud CLI。
-
如需初始化 gcloud CLI,请运行以下命令:
gcloud init -
Create or select a Google Cloud project.
Roles required to select or create a project
- Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
-
Create a project: To create a project, you need the Project Creator
(
roles/resourcemanager.projectCreator), which contains theresourcemanager.projects.createpermission. Learn how to grant roles.
-
Create a Google Cloud project:
gcloud projects create PROJECT_ID
Replace
PROJECT_IDwith a name for the Google Cloud project you are creating. -
Select the Google Cloud project that you created:
gcloud config set project PROJECT_ID
Replace
PROJECT_IDwith your Google Cloud project name.
-
Verify that billing is enabled for your Google Cloud project.
-
Enable the required API:
Roles required to enable APIs
To enable APIs, you need the Service Usage Admin IAM role (
roles/serviceusage.serviceUsageAdmin), which contains theserviceusage.services.enablepermission. Learn how to grant roles.gcloud services enable container.googleapis.com
-
Grant roles to your user account. Run the following command once for each of the following IAM roles:
roles/container.admin, roles/iam.serviceAccountAdmingcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
Replace the following:
PROJECT_ID: Your project ID.USER_IDENTIFIER: The identifier for your user account. For example,myemail@example.com.ROLE: The IAM role that you grant to your user account.
- 由于本教程使用 TPU Trillium (v6e),请选择可用的区域或可用区。如需了解详情,请参阅 Cloud TPU 配额。
准备环境
在本教程中,您将使用 Cloud Shell。gcloudCloud Shellhelm 预安装了本教程中使用的 kubectl、 和 命令行工具。
前往 Google Cloud 控制台。
在 Google Cloud 控制台窗口顶部,点击激活 Cloud Shell
按钮。一个 Cloud Shell 会话随即会在Google Cloud 控制台中的新框架内打开,并显示命令行提示符。
创建并激活 Python 虚拟环境:
python3 -m venv ray-env source ray-env/bin/activate安装 Ray CLI 和其他依赖项:
pip install "ray[default]==2.49.1"设置以下环境变量:
export PROJECT_ID=$(gcloud config get project) export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)") export GS_BUCKET=GS_BUCKET export KSA_NAME=KSA_NAME export NAMESPACE=default export CLUSTER_NAME=CLUSTER_NAME export REGION=REGION export ZONE=ZONE export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY替换以下内容:
GS_BUCKET:Cloud Storage 存储桶的名称。KSA_NAME:Kubernetes ServiceAccount 的名称。CLUSTER_NAME:新集群的名称。REGION:您的 TPU Trillium 容量可用的区域。ZONE:您的 TPU Trillium 容量可用的可用区。如需了解详情,请参阅 GKE 中的 TPU 可用性。ARTIFACT_REGISTRY:Artifact Registry 代码库的名称。
创建 GKE 集群
您可以在 GKE Autopilot 或 Standard 集群中的 TPU 上配置 KubeRay。我们建议您使用 Autopilot 集群获得全托管式 Kubernetes 体验。如需选择最适合您的工作负载的 GKE 操作模式,请参阅GKE 操作模式简介。
Autopilot
在 Cloud Shell 中,运行以下命令:
gcloud container clusters create-auto $CLUSTER_NAME \ --enable-ray-operator \ --machine-type=n1-standard-16 \ --location=$REGION如需与集群通信,请配置
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=$ZONE
标准
在 Cloud Shell 中,运行以下命令以创建启用 Ray operator 插件的 Standard 集群:
gcloud container clusters create $CLUSTER_NAME \ --addons=RayOperator \ --addons GcsFuseCsiDriver \ --machine-type=n1-standard-16 \ --workload-pool=$PROJECT_ID.svc.id.goog \ --location=$ZONE此命令还会启用
GcsFuseCsiDriver,从而允许 Pod 将 Cloud Storage 存储分区作为本地文件系统进行装载。集群创建可能需要几分钟的时间。如需与集群通信,请配置
kubectl:gcloud container clusters get-credentials CLUSTER_NAME \ --location=LOCATION创建多主机 TPU 切片节点池:
gcloud container node-pools create v6e-16 \ --location=$ZONE \ --cluster=$CLUSTER_NAME \ --machine-type=ct6e-standard-4t \ --threads-per-core=1 \ --tpu-topology=4x4 \ --num-nodes=4
GKE 会预配一个由四个 TPU Trillium (v6e) 虚拟机组成的节点池,这些虚拟机共同配置为具有 4x4 拓扑的多主机 TPU 切片,可用于分布式训练工作负载。
启用了 Ray 操作器的 GKE 集群会自动在集群中安装 KubeRay 和 KubeRay TPU webhook。
配置 Cloud Storage 存储分区和服务账号
创建一个 Cloud Storage 存储分区,用于在多主机 TPU 节点之间共享检查点。
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}如需启用对 Cloud Storage 存储分区的访问权限,请创建 Kubernetes ServiceAccount:
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}如需启用对 Cloud Storage 存储分区的访问权限,请向服务账号添加所需的 IAM 政策绑定:
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \ --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \ --role "roles/storage.objectUser"
创建训练脚本
以下脚本使用 Ray Train 的 JaxTrainer 运行分布式 MaxText 训练作业。该脚本可为多主机 TPU 切片节点池配置训练环境,并在每个工作器节点上运行 MaxText 训练作业。train_loop_per_worker 函数封装了 MaxText 主要入口点,并使用 Ray 的分布式调度程序在多主机 TPU 切片上执行 MaxText 训练器。
将以下 Python 脚本保存为
maxtext_ray_trainer.py:如需托管自定义映像,请创建 Artifact Registry 制品库:
gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \ --repository-format=docker --location=${REGION} && \ gcloud auth configure-docker ${REGION}-docker.pkg.dev如需构建包含用于训练的 Ray 和 MaxText 依赖项的映像,请创建
Dockerfile:构建 Docker 映像、为其添加标记并将其推送到 Artifact Registry:
export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest gcloud builds submit --tag ${DOCKER_IMAGE}
训练模型
将以下示例清单保存为
maxtext-tpu-cluster.yaml:上述 RayCluster 规范会创建一个 TPU 工作器组,每个副本包含四个工作器 (
numOfHosts: 4)。每个工作器请求 4 个 TPU 芯片 (google.com/tpu: "4")。工作器将调度到运行 TPU Trillium (tpu-v6e-slice) 的节点上,该节点是同一并置多主机切片的一部分。KubeRay 会以原子方式扩缩所有四个 worker,并且 GKE 会通过变更网络钩子来引导所需的 JAX 环境变量以及用于调度的 Pod 亲和性。如需在 YAML 文件中配置所需的值,请使用
envsubst创建 RayCluster:envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -验证集群是否已准备就绪并正在运行:
kubectl get rayclusters maxtext-tpu-cluster输出应类似如下所示:
NAME DESIRED WORKERS AVAILABLE WORKERS CPUS MEMORY GPUS STATUS AGE maxtext-tpu-cluster 4 4 40 798027216Ki 0 ready 11m如需通过 Ray 头服务访问 Ray 信息中心,请建立端口转发会话:
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &验证 RayCluster 是否可从本地环境访问:
ray list nodes --address http://localhost:8265输出应类似如下所示:
======== List: 2025-09-13 03:53:16.988269 ======== Stats: ------------------------------ Total: 5 Table: ------------------------------ NODE_ID NODE_IP IS_HEAD_NODE STATE STATE_MESSAGE NODE_NAME RESOURCES_TOTAL LABELS 0 92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56 10.84.0.9 (...)将 JaxTrainer 脚本提交到 RayCluster,并检查 RayJob 是否成功完成:
ray job submit \ --address http://localhost:8265 \ -- python /app/maxtext_ray_trainer.py \ /app/maxtext/src/MaxText/configs/base.yml \ base_output_directory=/data/ \ dataset_type=synthetic \ per_device_batch_size=1 \ max_target_length=4096 \ model_name=llama3-8b \ steps=100 \ ici_fsdp_parallelism=4 \ ici_tensor_parallelism=4 \ run_name=rayjob-8b-4096-tp4-4x4上述命令会提交 Python 脚本,该脚本会调用 JaxTrainer Ray 代码到 RayCluster。
ray job submit命令包含一些特定于 MaxText 的实参,用于传递给模型配置。在终端中,您应该会看到类似如下所示的输出:
(RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster] ------------------------------------------ Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded ------------------------------------------
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留该项目但删除各个资源。
删除 RayCluster:
kubectl delete raycluster maxtext-tpu-cluster删除 GKE 集群:
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE删除 Cloud Storage 存储桶:
gsutil rm -r gs://${GS_BUCKET}删除 Artifact Registry 代码库:
gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
后续步骤
- 了解 Ray on Kubernetes。
- 了解如何在 GKE 上通过 TPU 部署 vLLM。
- 了解如何在 GKE 上使用 TPU 部署 SDXL。
- 详细了解 GKE 中的 TPU。