在 GKE 中部署 TPU 多片


本页面介绍了如何使用 Cloud TPU 多切片配置在 Google Kubernetes Engine (GKE) 中部署工作负载,以进行经济高效的大规模训练。

在 GKE 中配置多切片之前,请确保您熟悉以下概念:

  1. Cloud TPU 简介
  2. Cloud TPU 系统架构
  3. 关于 GKE 中的 TPU

什么是 TPU 多切片

TPU 多切片是 TPU 切片中虚拟机的架构组织,其中两个或更多 Cloud TPU 切片通过数据中心网络 (DCN) 进行通信。多切片支持全栈、经济高效的大规模的训练,具有近线性扩容能力,可达到数万个 TPU 芯片。在多切片配置中,GKE 在多个 TPU 切片上部署多切片工作负载。切片中的 TPU 芯片之间的通信通过芯片间互连 (ICI) 进行。切片之间的通信是通过 DCN 进行的。

如果您的作业太大而无法放在单个 TPU 切片上,我们建议您使用多切片。

GKE 中的多切片可用性

  • Standard 在 1.27.4-gke.900 版及更高版本中支持多切片。
  • Autopilot 在 1.29.2-gke.1521000 及更高版本中支持多切片。
  • 多切片支持 JAX 和 PyTorch 框架。支持的最低 JAX 版本为 2.1。
  • 多切片仅支持多主机 TPU 切片节点池。例如,您不能将 Multislice 与具有 2x2x1 拓扑的 ct4p-hightpu-4t 或具有 2x2 拓扑的 ct5lp-hightpu-4t 结合使用,因为这些是单主机 TPU 切片节点池。
  • 多切片仅支持同步多控制器训练。
  • 多切片工作负载只能在共享相同 TPU 类型、大小和拓扑的 TPU 切片中运行。

准备工作

在开始之前,请确保您已执行以下任务:

  • 启用 Google Kubernetes Engine API。
  • 启用 Google Kubernetes Engine API
  • 如果您要使用 Google Cloud CLI 执行此任务,请安装初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行 gcloud components update 以获取最新版本。

在多切片上运行工作负载

本部分介绍了如何在多切片上运行工作负载。如果您使用 GKE Autopilot 模式,请跳到运行多切片工作负载部分。默认情况下,运行 1.29.2-gke.1521000 版或更高版本的 Autopilot 集群会启用 TPU。

准备 Standard 模式节点池

本部分包括以下步骤:

  1. 创建三个多主机 TPU 切片节点池
  2. 验证节点池状态

创建 TPU 切片节点池

您可以创建多个多主机 TPU 切片节点池。在本指南中,创建三个多主机 TPU 切片节点池以运行多切片工作负载。您可以使用 Google Cloud CLI、Terraform 或 Google Cloud 控制台创建多主机 TPU 切片节点池。

gcloud

gcloud container node-pools create POOL_NAME \
    --location=LOCATION \
    --cluster=CLUSTER_NAME \
    --node-locations=NODE_ZONE \
    --machine-type=MACHINE_TYPE \
    --tpu-topology=TPU_TOPOLOGY \
    --num-nodes=NUM_NODES \
    [--spot \]
    [--enable-autoscaling \
      --max-nodes MAX_NODES]
    [--reservation-affinity=specific \
    --reservation=RESERVATION_NAME]

请替换以下内容:

  • POOL_NAME:新节点池的名称。
  • LOCATION:基于您要使用的 TPU 版本的可用区名称:

    • 对于 TPU v4,请使用 us-central2-b
    • ct5l- 开头的 TPU v5e 机器类型绝不会是多主机。
    • 对于以 ct5lp- 开头的 TPU v5e 机器类型,请使用 us-west1-cus-west4-aus-west4-bus-central1-aus-east1-cus-east5-beurope-west4-a
    • 对于以 ct5p- 开头的 TPU v5p 机器类型,请使用 us-east1-dus-east5-aus-east5-c

    如需了解详情,请参阅 GKE 中的 TPU 可用性

  • CLUSTER_NAME:集群的名称。

  • NODE_ZONE:GKE 在其中创建节点池的一个或多个可用区的英文逗号分隔列表。

  • MACHINE_TYPE:用于节点的机器类型。如需详细了解可用机器类型,请参阅 TPU 配置映射

  • TPU_TOPOLOGY:TPU 切片的物理拓扑。拓扑格式取决于 TPU 版本,如下所示:

    • TPU v4 或 v5p:使用 3 元组 ({A}x{B}x{C}) 定义拓扑,例如 4x4x4
    • TPU v5e:使用 2 元组 ({A}x{B}) 定义拓扑,例如 2x2

    如需了解详情,请参阅拓扑

  • NUM_NODES:节点池中的节点数。该值必须为零或 TPU_TOPOLOGY ({A}x{B}x{C}) 中定义的值的乘积除以每个虚拟机中的芯片数量。对于多主机 TPU v4 和 TPU v5e,每个虚拟机中的芯片数量为 4。因此,如果 TPU_TOPOLOGY2x4x4(每个虚拟机中有四个芯片的 TPU v4),则 NUM_NODES 为 32/4,等于 8。

(可选)您还可以使用以下标志:

  • RESERVATION_NAME:GKE 在创建节点池时使用的预留的名称。如果您省略此标志,则 GKE 会使用可用的 TPU 切片节点池。如需详细了解 TPU 预留,请参阅 TPU 预留
  • --spot:设置节点池以对 TPU 切片节点使用 Spot 虚拟机。创建节点池后,此设置将无法更改。如需了解详情,请参阅 Spot 虚拟机
  • --enable-autoscaling:创建启用了自动扩缩功能的节点池。当 GKE 扩缩多主机 TPU 切片节点池时,它会以原子方式将节点池从零扩容到大小上限。
    • MAX_NODES:节点池的大小上限。如果提供了 --enable-autoscaling,则必须使用 --max-nodes 标志,且值必须等于 TPU_TOPOLOGY 中定义的值的乘积 ({A}x{B}x{C}) 除以每个虚拟机中的芯片数量。

Terraform

  1. 确保您使用 google 提供程序 4.84.0 版或更高版本。
  2. 将以下块添加到 Terraform 配置中:

    resource "google_container_node_pool" "NODE_POOL_RESOURCE_NAME" {
      provider           = google
      project            = PROJECT_ID
      cluster            = CLUSTER_NAME
      name               = POOL_NAME
      location           = CLUSTER_LOCATION
      node_locations     = [NODE_ZONES]
      initial_node_count = NUM_NODES
    
      autoscaling {
        max_node_count = MAX_NODES
        location_policy      = "ANY"
      }
      node_config {
        machine_type = MACHINE_TYPE
        reservation_affinity {
          consume_reservation_type = "SPECIFIC_RESERVATION"
          key = "compute.googleapis.com/reservation-name"
          values = [RESERVATION_LABEL_VALUES]
        }
        spot = true
      }
    
      placement_policy {
        type = "COMPACT"
        tpu_topology = TPU_TOPOLOGY
      }
    }
    

    请替换以下内容:

    • NODE_POOL_RESOURCE_NAME:Terraform 模板中的节点池资源的名称。
    • PROJECT_ID:您的项目 ID。
    • CLUSTER_NAME:要在其中添加节点池的现有集群的名称。
    • POOL_NAME:要创建的节点池的名称。
    • CLUSTER_LOCATION:集群的计算位置。我们建议您使用区域级集群,以提高 Kubernetes 控制平面的可靠性。您还可以使用可用区级集群。如需了解详情,请参阅选择 TPU 版本和拓扑
    • NODE_ZONES:GKE 在其中创建节点池的一个或多个可用区的英文逗号分隔列表。
    • NUM_NODES:节点池中的节点数。值必须为零或 TPU 芯片数量的乘积除以 4,因为在多主机 TPU 切片中,每个 TPU 切片节点都有 4 个芯片。例如,如果 TPU_TOPOLOGY4x8,则有 32 个芯片,这意味着 NUM_NODES 必须为 8。如需详细了解 TPU 拓扑,请使用 TPU 配置映射中的表格。
    • TPU_TOPOLOGY:指示所需的 TPU 切片物理拓扑。拓扑格式取决于您使用的 TPU 版本:
      • 对于 TPU v4:使用 3 元组 ({A}x{B}x{C}) 定义拓扑,例如 4x4x4
      • 对于 TPU v5e:使用 2 元组 ({A}x{B}) 定义拓扑,例如 2x2

    (可选)您还可以使用以下变量:

    • RESERVATION_NAME:如果您使用 TPU 预留,则这是创建节点池时使用的预留资源的标签列表。如需详细了解如何填充 reservation_affinity 字段中的 RESERVATION_LABEL_VALUES,请参阅 Terraform 提供程序
    • autoscaling:创建启用了自动扩缩功能的节点池。当 GKE 扩缩多主机 TPU 切片节点池时,它会以原子方式将节点池从零扩容到大小上限。
      • MAX_NODES:节点池的大小上限。该值必须等于 TPU_TOPOLOGY ({A}x{B}x{C}) 中定义的值的乘积除以每个虚拟机中的芯片数量。
    • spot:可让节点池对 TPU 切片节点使用 Spot 虚拟机。创建节点池后,此设置将无法更改。如需了解详情,请参阅 Spot 虚拟机

控制台

如需创建具有 TPU 的节点池,请执行以下操作:

  1. 转到 Google Cloud 控制台中的 Google Kubernetes Engine 页面。

    前往 Google Kubernetes Engine

  2. 在集群列表中,点击您要修改的集群的名称。

  3. 点击 添加节点池

  4. 节点池详情部分中,勾选指定节点位置复选框。

  5. 根据您要使用的 TPU 版本选择可用区:

    • 对于 TPU v4,请使用 us-central2-b
    • ct5l- 开头的 TPU v5e 机器类型绝不会是多主机。
    • 对于以 ct5lp- 开头的 TPU v5e 机器类型,请使用 us-west1-cus-west4-aus-west4-bus-central1-aus-east1-cus-east5-beurope-west4-a
    • 对于以 ct5p- 开头的 TPU v5p 机器类型,请使用 us-east1-dus-east5-aus-east5-c
  6. 在导航窗格中,点击节点

  7. 机器配置部分中,选择 TPU

  8. 系列下拉菜单中,选择以下选项之一:

    • CT4P:适用于 TPU v4。
    • CT5LP:适用于 TPU v5e。
  9. 机器类型下拉菜单中,选择要用于节点的机器的名称。使用 TPU 配置映射表可了解如何定义用于创建多主机 TPU 切片节点池的机器类型和 TPU 拓扑。

  10. TPU 拓扑下拉菜单中,选择 TPU 切片的物理拓扑。

  11. 需要更改对话框中,点击进行更改

  12. 确保启动磁盘类型标准永久性磁盘SSD 永久性磁盘

  13. (可选)选中在 Spot 虚拟机上启用节点复选框,以对节点池中的节点使用 Spot 虚拟机。

  14. 点击创建

验证节点池状态

  1. 获取凭据,以便使用 kubectl 访问集群:

    gcloud container clusters get-credentials CLUSTER_NAME \
        --project=PROJECT_ID
    

    替换以下内容:

    • CLUSTER_NAME:集群的名称。
    • PROJECT_ID:您的项目 ID。
  2. 在 Cloud Shell 中使用 kubectl 查看 TPU 切片节点:

    kubectl get nodes -l cloud.google.com/gke-tpu-accelerator=TPU_ACCELERATOR \
       -l cloud.google.com/gke-tpu-topology=TPU_TOPOLOGY
    

    替换以下内容:

    • TPU_ACCELERATOR:创建节点池时使用的 TPU 加速器的类型。例如 tpu-v4-podslicetpu-v5-lite-devicetpu-v5-lite-podslice
    • TPU_TOPOLOGY:TPU 切片的物理拓扑

    输出类似于以下内容:

     NAME                                    STATUS   ROLES    AGE    VERSION
     gke-tpu-20ee2cce-5tv6                   Ready    <none>   34h     v1.28.1-gke.1066000
    

运行多切片工作负载

在本部分中,您将运行一个 JAX 工作负载(其中显示了 TPU 切片中的全球 TPU 芯片数量),然后退出。

如需运行 JAX 工作负载,请执行以下操作:

  1. 创建以下 tpu-multislice.yaml 清单:

    Autopilot

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-job
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: NUM_SLICES
          template:
            spec:
              parallelism: NUM_NODES
              completions: NUM_NODES
              backoffLimit: 0
              template:
                spec:
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
                    cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    - containerPort: 8431
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                     limits:
                        google.com/tpu: NUM_CHIPS
    

    标准

    apiVersion: jobset.x-k8s.io/v1alpha2
    kind: JobSet
    metadata:
      name: multislice-job
      annotations:
        alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
    spec:
      failurePolicy:
        maxRestarts: 4
      replicatedJobs:
        - name: slice
          replicas: NUM_SLICES
          template:
            spec:
              parallelism: NUM_NODES
              completions: NUM_NODES
              backoffLimit: 0
              template:
                spec:
                  hostNetwork: true
                  dnsPolicy: ClusterFirstWithHostNet
                  nodeSelector:
                    cloud.google.com/gke-tpu-accelerator: ACCELERATOR_TYPE
                    cloud.google.com/gke-tpu-topology: TPU_TOPOLOGY
                  containers:
                  - name: jax-tpu
                    image: python:3.8
                    ports:
                    - containerPort: 8471
                    - containerPort: 8080
                    - containerPort: 8431
                    securityContext:
                      privileged: true
                    command:
                    - bash
                    - -c
                    - |
                      pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                      python -c 'import jax; print("Global device count:", jax.device_count())'
                      sleep 60
                    resources:
                      limits:
                       google.com/tpu: NUM_CHIPS
    

    替换以下内容:

    • NUM_SLICES:TPU 切片节点池的数量。在这种情况下,NUM_SLICES 等于 3
    • ACCELERATOR_TYPE:创建节点池时使用的 TPU 加速器的类型。例如 tpu-v4-podslicetpu-v5-lite-devicetpu-v5-lite-podslice
    • TPU_TOPOLOGY:TPU 切片的物理拓扑。例如 4x4x42x2,具体取决于 TPU 版本。
    • NUM_NODES:节点池中的节点数。 值必须为零或 TPU_TOPOLOGY ({A}x{B}x{C}) 中定义的值的乘积除以每个虚拟机中的 TPU 芯片数量。对于多主机 TPU v4,每个虚拟机中的 TPU 芯片数量为 4。对于多主机 TPU v5e,每个虚拟机中的 TPU 芯片数量为 1、4 或 8。因此,如果 TPU_TOPOLOGY2x4x4(每个虚拟机中具有四个 TPU 芯片的 TPU v4),则 NUM_NODES 为 32/4,等于 8。
    • NUM_CHIPS:对于多主机 TPU v4,每个虚拟机中的 TPU 芯片数量为 4。对于多主机 TPU v5e,每个虚拟机中的 TPU 芯片数量为 1、4 或 8。如需了解详情,请参阅 TPU 切片中虚拟机上的 TPU 芯片

    在此清单中:

    • JobSet 是与 JobSet 名称同名的无头服务,在本例中为 multislice-job
    • alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool 注解用于配置 Pod 亲和性,以确保所有 Pod 调度到同一切片上
    • maxRestarts: 4 表示在子作业失败时 GKE 重启 JobSet 的最大次数。如果 JobSet 重启达到定义的最大值,则会将 JobSet 标记为失败。
    • parallelismcompletions 字段等于每个节点池中的节点数。
    • backoff 为 0,因为多切片仅支持同步多控制器训练。必须设置为 0。当任何 pod 失败时,使作业失败。
    • 亲和性部分的值可确保在一组多切片中只有一个 TPU 多切片工作负载运行。
    • containerPort: 8080 是 MXLA 协调器的端口
    • containerPort: 8431 是用于导出 TPU 用量指标的端口
    • securityContext: privileged: true 表示节点已启用特权模式以访问 TPU。GKE 1.28 或更高版本中的节点无需启用特权模式即可访问 TPU。如需了解详情,请参阅在不使用特权模式的情况下运行容器
  2. 应用清单:

    kubectl apply -f tpu-multislice.yaml
    
  3. 确认工作负载已被允许:

    kubectl get jobsets
    

    输出类似于以下内容:

    NAME            RESTARTS   COMPLETED   AGE
    multislice-job                         3s
    
  4. 监控预配 Pod 的状态:

    kubectl get pods
    

    输出类似于以下内容:

     NAME                                READY   STATUS      RESTARTS   AGE
     multislice-job-slice-0-0-wzq9t      0/1     Completed   0          2m31s
     multislice-job-slice-0-1-zf4dp      0/1     Completed   0          2m30s
     multislice-job-slice-1-0-hbfn5      0/1     Completed   0          2m31s
     multislice-job-slice-1-1-45fgl      0/1     Completed   0          2m30s
     multislice-job-slice-2-0-wjbp4      0/1     Completed   0          2m30s
     multislice-job-slice-2-1-lwnvs      0/1     Completed   0          2m30s
    

multislice-job JobSet 会安排、创建 Pod,然后运行 Pod 以完成操作。Pod 名称的格式为 <jobsetName>-<jobName>-<jobReplicaIndex>-<randomSuffix>jobsetName 前缀决定了 Pod 所属的 JobSet。

其他配置

以下各部分介绍了可应用于多切片的其他配置。

在 GKE Standard Pod 上启用 hostNetwork

为了提升 TPU 切片之间的网络性能,我们建议您开启 hostNetworking。在 Pod 规范中使用 hostNetwork: true 跳过所有 Kubernetes 网络堆栈,让 Kubernetes Pod 直接使用主机网络进行虚拟机之间的通信。

如需启用 hostNetworking,请从 Pod 规范中移除以下两行:

hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet

如需继续使用 podHostnames 通过 hostNetwork 发现工作器节点,请设置 dnsPolicy: ClusterFirstWithHostNet。如果您要运行自动恢复训练作业,并且需要使用相同的名称来重新加载相同的检查点,则这一点很重要。

日志记录

如果您已在集群中启用 GKE 系统日志记录,则由 GKE 节点(包括 TPU 切片节点)上运行的容器发出的日志会显示在 Logs Explorer 中。

您可以使用 Logs Explorer 通过以下过滤条件来查看工作负载的容器日志,以此查看 GKE 中的日志

resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME

对 TPU 切片和工作器使用以下过滤条件:

resource.type="k8s_container"
resource.labels.cluster_name=CLUSTER_NAME
labels."k8s-pod/jobset_sigs_k8s_io/jobset-name"=JOBSET_NAME
resource.labels.pod_name:<jobSetName>-<replicateJobName>-<job-index>-<worker-index>

如需了解详情,请参阅查看 GKE TPU 日志

可观测性和指标

除了一般的 TPU 指标之外,还有 4 个多切片专用 TPU 运行时指标。GKE 1.29.1-gke.1016000 版或更高版本提供了这些指标。TPU 工作负载必须使用 JAX 0.4.24

以下是可用的多切片指标:

  • DCN(数据中心网络)传输延迟时间:多切片流量的网络传输延迟时间分布。
  • 总体延迟时间:多切片流量的端到端总体延迟时间分布。
  • 主机到设备传输延迟时间:多切片流量的每个数据块的主机到设备传输延迟时间分布。
  • 设备到主机传输延迟时间:多切片流量的每个数据块的设备到主机传输延迟时间分布。

这些指标位于 Kubernetes 容器 (k8s_container) 架构中:

  • kubernetes.io/container/multislice/network/dcn_transfer_latencies
  • kubernetes.io/container/multislice/network/collective_end_to_end_latencies
  • kubernetes.io/container/multislice/accelerator/host_to_device_transfer_latencies
  • kubernetes.io/container/multislice/accelerator/device_to_host_transfer_latencies

TPU 切片与多切片

下表区分了 TPU 切片和多切片的架构组织:

TPU 切片 多切片
互连 工作负载在单个 TPU 切片上运行。切片中的所有 TPU 芯片与 ICI 连接。 工作负载在多个 TPU 切片上运行。切片中的通信通过 ICI 进行。切片之间的通信通过 DCN 进行。
支持的节点池 单主机 TPU 切片和多主机 TPU 切片 多主机 TPU 切片组
推荐的工作负载类型 IndexedJob 或 JobSet JobSet

后续步骤