本页面介绍了如何使用 Cloud TPU 多切片配置在 Google Kubernetes Engine (GKE) 中部署工作负载,以进行经济高效的大规模训练。
在 GKE 中配置多切片之前,您应该熟悉以下概念:
什么是 TPU 多切片
TPU 多切片是 TPU 切片中虚拟机的架构组织,其中两个或更多 Cloud TPU 切片通过数据中心网络 (DCN) 进行通信。多切片支持全栈、经济高效的大规模的训练,具有近线性扩容能力,可达到数万个 TPU 芯片。在多切片配置中,GKE 在多个 TPU 切片上部署多切片工作负载。切片中的 TPU 芯片之间的通信通过芯片间互连 (ICI) 进行。切片之间的通信是通过 DCN 进行的。
如果您的作业太大而无法放在单个 TPU 切片上,我们建议您使用多切片。
GKE 中的多切片可用性
- Standard 在 1.27.4-gke.900 版及更高版本中支持多切片。
- Autopilot 在 1.29.2-gke.1521000 及更高版本中支持多切片。
- 多切片支持 JAX 和 PyTorch 框架。支持的最低 JAX 版本为 2.1。
- 多切片仅支持多主机 TPU 切片节点池。例如,您不能将 Multislice 与具有
2x2x1
拓扑的ct4p-hightpu-4t
或具有2x2
拓扑的ct5lp-hightpu-4t
结合使用,因为这些是单主机 TPU 切片节点池。 - 多切片仅支持同步多控制器训练。
- 多切片工作负载只能在共享相同 TPU 类型、大小和拓扑的 TPU 切片中运行。
准备工作
在开始之前,请确保您已执行以下任务:
- 启用 Google Kubernetes Engine API。 启用 Google Kubernetes Engine API
- 如果您要使用 Google Cloud CLI 执行此任务,请安装并初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行
gcloud components update
以获取最新版本。
- 创建一个 Standard 集群或 Autopilot 集群来运行支持多切片的版本。如需查看受支持的版本,请参阅 GKE 中的多切片可用性。
- 确保您的项目具有足够的配额,以便用于 GKE 中的 Cloud TPU。
- 安装 JobSet v0.2.3 或更高版本。
在多切片上运行工作负载
本部分介绍了如何在多切片上运行工作负载。如果您使用 GKE Autopilot 模式,请跳到运行多切片工作负载部分。默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。
准备 Standard 模式节点池
本部分包括以下步骤:
- 创建三个多主机 TPU 切片节点池
- 验证节点池状态
创建 TPU 切片节点池
您可以创建多个多主机 TPU 切片节点池。在本指南中,创建三个多主机 TPU 切片节点池以运行多切片工作负载。您可以使用 Google Cloud CLI、Terraform 或 Google Cloud 控制台创建多主机 TPU 切片节点池。
gcloud
gcloud container node-pools create POOL_NAME \
--location=LOCATION \
--cluster=CLUSTER_NAME \
--node-locations=NODE_ZONE \
--machine-type=MACHINE_TYPE \
--tpu-topology=TPU_TOPOLOGY \
--num-nodes=NUM_NODES \
[--spot \]
[--enable-autoscaling \
--max-nodes MAX_NODES]
[--reservation-affinity=specific \
--reservation=RESERVATION_NAME]
请替换以下内容:
POOL_NAME
:新节点池的名称。LOCATION
:基于您要使用的 TPU 版本的可用区名称:- 对于 TPU v4,请使用
us-central2-b
。 - 以
ct5l-
开头的 TPU v5e 机器类型绝不会是多主机。 - 对于以
ct5lp-
开头的 TPU v5e 机器类型,请使用us-west1-c
、us-west4-a
、us-west4-b
、us-central1-a
、us-east1-c
、us-east5-b
或europe-west4-a
。 - 对于以
ct5p-
开头的 TPU v5p 机器类型,请使用us-east1-d
、us-east5-a
或us-east5-c
。
如需了解详情,请参阅 GKE 中的 TPU 可用性。
- 对于 TPU v4,请使用
CLUSTER_NAME
:集群的名称。NODE_ZONE
:GKE 在其中创建节点池的一个或多个可用区的英文逗号分隔列表。MACHINE_TYPE
:用于节点的机器类型。如需详细了解可用机器类型,请参阅 TPU 配置映射。TPU_TOPOLOGY
:TPU 切片的物理拓扑。拓扑格式取决于 TPU 版本,如下所示:- TPU v4 或 v5p:使用 3 元组 (
{A}x{B}x{C}
) 定义拓扑,例如4x4x4
。 - TPU v5e:使用 2 元组 (
{A}x{B}
) 定义拓扑,例如2x2
。
如需了解详情,请参阅拓扑。
- TPU v4 或 v5p:使用 3 元组 (
NUM_NODES
:节点池中的节点数。该值必须为零或TPU_TOPOLOGY
({A}x{B}x{C}
) 中定义的值的乘积除以每个虚拟机中的芯片数量。对于多主机 TPU v4 和 TPU v5e,每个虚拟机中的芯片数量为 4。因此,如果TPU_TOPOLOGY
为2x4x4
(每个虚拟机中有四个芯片的 TPU v4),则NUM_NODES
为 32/4,等于 8。
(可选)您还可以使用以下标志:
RESERVATION_NAME
:GKE 在创建节点池时使用的预留的名称。如果您省略此标志,则 GKE 会使用可用的 TPU 切片节点池。如需详细了解 TPU 预留,请参阅 TPU 预留。--spot
:设置节点池以对 TPU 切片节点使用 Spot 虚拟机。创建节点池后,您将无法更改此设置。如需了解详情,请参阅 Spot 虚拟机。--enable-autoscaling
:创建启用了自动扩缩功能的节点池。当 GKE 扩缩多主机 TPU 切片节点池时,它会以原子方式将节点池从零扩容到大小上限。MAX_NODES
:节点池的大小上限。如果提供了--enable-autoscaling
,则必须使用--max-nodes
标志,且值必须等于TPU_TOPOLOGY
中定义的值的乘积 ({A}x{B}x{C}
) 除以每个虚拟机中的芯片数量。
Terraform
- 确保您使用
google
提供程序 4.84.0 版或更高版本。 将以下块添加到 Terraform 配置中:
resource "google_container_node_pool" "NODE_POOL_RESOURCE_NAME" { provider = google project = PROJECT_ID cluster = CLUSTER_NAME name = POOL_NAME location = CLUSTER_LOCATION node_locations = [NODE_ZONES] initial_node_count = NUM_NODES autoscaling { max_node_count = MAX_NODES location_policy = "ANY" } node_config { machine_type = MACHINE_TYPE reservation_affinity { consume_reservation_type = "SPECIFIC_RESERVATION" key = "compute.googleapis.com/reservation-name" values = [RESERVATION_LABEL_VALUES] } spot = true } placement_policy { type = "COMPACT" tpu_topology = TPU_TOPOLOGY } }
请替换以下内容:
NODE_POOL_RESOURCE_NAME
:Terraform 模板中的节点池资源的名称。PROJECT_ID
:您的项目 ID。CLUSTER_NAME
:要在其中添加节点池的现有集群的名称。POOL_NAME
:要创建的节点池的名称。CLUSTER_LOCATION
:集群的计算位置。我们建议您使用区域级集群,以提高 Kubernetes 控制平面的可靠性。您还可以使用可用区级集群。如需了解详情,请参阅选择 TPU 版本和拓扑。NODE_ZONES
:GKE 在其中创建节点池的一个或多个可用区的英文逗号分隔列表。NUM_NODES
:节点池中的节点数。值必须为零或 TPU 芯片数量的乘积除以 4,因为在多主机 TPU 切片中,每个 TPU 切片节点都有 4 个芯片。例如,如果TPU_TOPOLOGY
为4x8
,则有 32 个芯片,这意味着NUM_NODES
必须为 8。如需详细了解 TPU 拓扑,请使用 TPU 配置映射中的表格。TPU_TOPOLOGY
:指示所需的 TPU 切片物理拓扑。拓扑格式取决于您使用的 TPU 版本:- 对于 TPU v4:使用 3 元组 (
{A}x{B}x{C}
) 定义拓扑,例如4x4x4
。 - 对于 TPU v5e:使用 2 元组 (
{A}x{B}
) 定义拓扑,例如2x2
。
- 对于 TPU v4:使用 3 元组 (
(可选)您还可以使用以下变量:
RESERVATION_NAME
:如果您使用 TPU 预留,则这是创建节点池时使用的预留资源的标签列表。如需详细了解如何填充reservation_affinity
字段中的RESERVATION_LABEL_VALUES
,请参阅 Terraform 提供程序。autoscaling
:创建启用了自动扩缩功能的节点池。当 GKE 扩缩多主机 TPU 切片节点池时,它会以原子方式将节点池从零扩容到大小上限。MAX_NODES
:节点池的大小上限。该值必须等于TPU_TOPOLOGY
({A}x{B}x{C}
) 中定义的值的乘积除以每个虚拟机中的芯片数量。
spot
:可让节点池对 TPU 切片节点使用 Spot 虚拟机。创建节点池后,您将无法更改此设置。如需了解详情,请参阅 Spot 虚拟机。
控制台
如需创建具有 TPU 的节点池,请执行以下操作:
转到 Google Cloud 控制台中的 Google Kubernetes Engine 页面。
在集群列表中,点击您要修改的集群的名称。
点击 add_box 添加节点池。
在节点池详情部分中,勾选指定节点位置复选框。
根据您要使用的 TPU 版本选择可用区:
- 对于 TPU v4,请使用
us-central2-b
。 - 以
ct5l-
开头的 TPU v5e 机器类型绝不会是多主机。 - 对于以
ct5lp-
开头的 TPU v5e 机器类型,请使用us-west1-c
、us-west4-a
、us-west4-b
、us-central1-a
、us-east1-c
、us-east5-b
或europe-west4-a
。 - 对于以
ct5p-
开头的 TPU v5p 机器类型,请使用us-east1-d
、us-east5-a
或us-east5-c
。
- 对于 TPU v4,请使用
在导航窗格中,点击节点。
在机器配置部分中,选择 TPU。
在系列下拉菜单中,选择以下选项之一:
- CT4P:适用于 TPU v4。
- CT5LP:适用于 TPU v5e。
在机器类型下拉菜单中,选择要用于节点的机器的名称。使用 TPU 配置映射表可了解如何定义用于创建多主机 TPU 切片节点池的机器类型和 TPU 拓扑。
在 TPU 拓扑下拉菜单中,选择 TPU 切片的物理拓扑。
在需要更改对话框中,点击进行更改。
确保启动磁盘类型为标准永久性磁盘或 SSD 永久性磁盘。
(可选)选中在 Spot 虚拟机上启用节点复选框,以对节点池中的节点使用 Spot 虚拟机。
点击创建。
验证节点池状态
获取凭据,以便使用
kubectl
访问集群:gcloud container clusters get-credentials CLUSTER_NAME \ --project=PROJECT_ID
替换以下内容:
CLUSTER_NAME
:集群的名称。PROJECT_ID
:您的项目 ID。
在 Cloud Shell 中使用
kubectl
查看 TPU 切片节点:kubectl get nodes -l cloud.google.com/gke-tpu-accelerator=TPU_ACCELERATOR \ -l cloud.google.com/gke-tpu-topology=TPU_TOPOLOGY
替换以下内容:
TPU_ACCELERATOR
:创建节点池时使用的 TPU 加速器的类型。例如tpu-v4-podslice
、tpu-v5-lite-device
或tpu-v5-lite-podslice
。TPU_TOPOLOGY
:TPU 切片的物理拓扑。
输出类似于以下内容:
NAME STATUS ROLES AGE VERSION gke-tpu-20ee2cce-5tv6 Ready <none> 34h v1.28.1-gke.1066000
运行多切片工作负载
在本部分中,您将运行一个 JAX 工作负载(其中显示了 TPU 切片中的全球 TPU 芯片数量),然后退出。
如需运行 JAX 工作负载,请执行以下操作:
创建以下
tpu-multislice.yaml
清单:Autopilot
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-job annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: NUM_SLICES template: spec: parallelism: NUM_NODES completions: NUM_NODES backoffLimit: 0 template: spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 - containerPort: 8431 command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' sleep 60 resources: limits: google.com/tpu: NUM_CHIPS
标准
apiVersion: jobset.x-k8s.io/v1alpha2 kind: JobSet metadata: name: multislice-job annotations: alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool spec: failurePolicy: maxRestarts: 4 replicatedJobs: - name: slice replicas: NUM_SLICES template: spec: parallelism: NUM_NODES completions: NUM_NODES backoffLimit: 0 template: spec: hostNetwork: true dnsPolicy: ClusterFirstWithHostNet nodeSelector: cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY containers: - name: jax-tpu image: python:3.8 ports: - containerPort: 8471 - containerPort: 8080 - containerPort: 8431 securityContext: privileged: true command: - bash - -c - | pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html python -c 'import jax; print("Global device count:", jax.device_count())' sleep 60 resources: limits: google.com/tpu: NUM_CHIPS
替换以下内容:
NUM_SLICES
:TPU 切片节点池的数量。在这种情况下,NUM_SLICES
等于3
。ACCELERATOR_TYPE
:创建节点池时使用的 TPU 加速器的类型。例如tpu-v4-podslice
、tpu-v5-lite-device
或tpu-v5-lite-podslice
。TPU_TOPOLOGY
:TPU 切片的物理拓扑。例如4x4x4
或2x2
,具体取决于 TPU 版本。NUM_NODES
:节点池中的节点数。 值必须为零或TPU_TOPOLOGY
({A}x{B}x{C}
) 中定义的值的乘积除以每个虚拟机中的 TPU 芯片数量。对于多主机 TPU v4,每个虚拟机中的 TPU 芯片数量为 4。对于多主机 TPU v5e,每个虚拟机中的 TPU 芯片数量为 1、4 或 8。因此,如果TPU_TOPOLOGY
为2x4x4
(每个虚拟机中具有四个 TPU 芯片的 TPU v4),则NUM_NODES
为 32/4,等于 8。NUM_CHIPS
:对于多主机 TPU v4,每个虚拟机中的 TPU 芯片数量为 4。对于多主机 TPU v5e,每个虚拟机中的 TPU 芯片数量为 1、4 或 8。如需了解详情,请参阅 TPU 切片中虚拟机上的 TPU 芯片。
在此清单中:
- JobSet 是与 JobSet 名称同名的无头服务,在本例中为
multislice-job
。 maxRestarts: 4
表示在子作业失败时 GKE 重启 JobSet 的最大次数。如果 JobSet 重启达到定义的最大值,则会将 JobSet 标记为失败。parallelism
和completions
字段等于每个节点池中的节点数。backoff
为 0,因为多切片仅支持同步多控制器训练。必须设置为 0。当任何 pod 失败时,使作业失败。- 亲和性部分的值可确保在一组多切片中只有一个 TPU 多切片工作负载运行。
containerPort: 8080
是 MXLA 协调器的端口containerPort: 8431
是用于导出 TPU 用量指标的端口securityContext: privileged: true
表示节点已启用特权模式以访问 TPU。GKE 1.28 或更高版本中的节点无需启用特权模式即可访问 TPU。如需了解详情,请参阅在不使用特权模式的情况下运行容器。
应用清单:
kubectl apply -f tpu-multislice.yaml
确认工作负载已被允许:
kubectl get jobsets
输出类似于以下内容:
NAME RESTARTS COMPLETED AGE multislice-job 3s
监控预配 Pod 的状态:
kubectl get pods
输出类似于以下内容:
NAME READY STATUS RESTARTS AGE multislice-job-slice-0-0-wzq9t 0/1 Completed 0 2m31s multislice-job-slice-0-1-zf4dp 0/1 Completed 0 2m30s multislice-job-slice-1-0-hbfn5 0/1 Completed 0 2m31s multislice-job-slice-1-1-45fgl 0/1 Completed 0 2m30s multislice-job-slice-2-0-wjbp4 0/1 Completed 0 2m30s multislice-job-slice-2-1-lwnvs 0/1 Completed 0 2m30s
multislice-job
JobSet 会安排、创建 Pod,然后运行 Pod 以完成操作。Pod 名称的格式为 <jobsetName>-<jobName>-<jobReplicaIndex>-<randomSuffix>
。jobsetName
前缀决定了 Pod 所属的 JobSet。
其他配置
以下各部分介绍了可应用于多切片的其他配置。
在 GKE Standard Pod 上启用 hostNetwork
为了提升 TPU 切片之间的网络性能,我们建议您开启 hostNetworking
。在 Pod 规范中使用 hostNetwork: true
跳过所有 Kubernetes 网络堆栈,让 Kubernetes Pod 直接使用主机网络进行虚拟机之间的通信。
如需启用 hostNetworking
,请从 Pod 规范中移除以下两行:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
如需继续使用 podHostnames
通过 hostNetwork
发现工作器节点,请设置 dnsPolicy: ClusterFirstWithHostNet
。如果您要运行自动恢复训练作业,并且需要使用相同的名称来重新加载相同的检查点,则这一点很重要。
日志记录
如果您已在集群中启用 GKE 系统日志记录,则由 GKE 节点(包括 TPU 切片节点)上运行的容器发出的日志会显示在 Logs Explorer 中。
您可以使用 Logs Explorer 通过以下过滤条件来查看工作负载的容器日志,以此查看 GKE 中的日志:
resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME
对 TPU 切片和工作器使用以下过滤条件:
resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME
resource.labels.pod_name:<jobSetName>-<replicateJobName>-<job-index>-<worker-index>
可观测性和指标
除了一般的 TPU 指标之外,还有 4 个多切片专用 TPU 运行时指标。GKE 1.29.1-gke.1016000 版或更高版本提供了这些指标。TPU 工作负载必须使用 JAX 0.4.24 版
以下是可用的多切片指标:
- DCN(数据中心网络)传输延迟时间:多切片流量的网络传输延迟时间分布。
- 总体延迟时间:多切片流量的端到端总体延迟时间分布。
- 主机到设备传输延迟时间:多切片流量的每个数据块的主机到设备传输延迟时间分布。
- 设备到主机传输延迟时间:多切片流量的每个数据块的设备到主机传输延迟时间分布。
这些指标位于 Kubernetes 容器 (k8s_container
) 架构中:
kubernetes.io/container/multislice/network/dcn_transfer_latencies
kubernetes.io/container/multislice/network/collective_end_to_end_latencies
kubernetes.io/container/multislice/accelerator/host_to_device_transfer_latencies
kubernetes.io/container/multislice/accelerator/device_to_host_transfer_latencies
TPU 切片与多切片
下表区分了 TPU 切片和多切片的架构组织:
TPU 切片 | 多切片 | |
---|---|---|
互连 | 工作负载在单个 TPU 切片上运行。切片中的所有 TPU 芯片与 ICI 连接。 | 工作负载在多个 TPU 切片上运行。切片中的通信通过 ICI 进行。切片之间的通信通过 DCN 进行。 |
支持的节点池 | 单主机 TPU 切片和多主机 TPU 切片 | 多主机 TPU 切片组 |
推荐的工作负载类型 | IndexedJob 或 JobSet | JobSet |