Cloud TPU 多切片概览

Cloud TPU Multislice 是一种全栈性能伸缩技术,它使训练作业能够通过简单的数据并行处理在单个 Pod 内或多个 Pod 中的切片上使用多个 TPU 切片。对于 TPU v4 芯片,这意味着训练作业在一次运行中可以使用超过 4096 个芯片。对于需要少于 4096 个芯片的训练作业,单个切片可以实现最佳性能。但是,多个较小的切片更容易使用,因此当 MultiSlice 与较小的切片一起使用时,启动时间会更短。

多个 Slice 可线性扩缩性能

在多切片配置中部署时,每个切片中的 TPU 芯片通过芯片间互连 (ICI) 进行通信。不同切片中的 TPU 芯片通过将数据传输到 CPU(主机),而主机又通过数据中心网络 (DCN) 传输数据。

多切片数据流

开发者无需编写代码来实现切片间 DCN 通信。XLA 编译器会为您生成该代码,并将通信与计算重叠,以实现最佳性能。

概念

加速器类型
构成 MultiSlice 的每个 TPU 切片的形状。多切片请求中的每个切片都属于同一加速器类型。加速器类型由 TPU 类型(v4 或 v5e)后跟 TensorCore 数量组成。例如,v4-128 指定具有 128 个 TensorCore 的 TPU v4。
自动修复
当某个切片遇到维护事件、抢占或硬件故障时,Cloud TPU 将创建新切片。在极少数情况下,当资源不足以创建新切片时,需要等到硬件可用时才能完成创建。创建新切片后,Multislice 环境中的所有其他切片都将重启,以便继续训练。正确配置的启动脚本后,训练脚本可在无需用户干预的情况下自动重新启动,并从最新的检查点加载和恢复。
数据集
模型用于训练或推断的数据。
数据中心网络 (DCN)
延迟时间较长、吞吐量较低的网络(与 ICI 相比),它连接了多切片配置中的 TPU 切片。
群组日程
同时预配所有 TPU 切片时,保证成功预配所有切片或一个也不成功预配。
主机
主机是运行虚拟机的物理计算机。一个主机一次最多可以运行四个虚拟机。每个虚拟机都有一个专用 TPU。
推断
将预训练的机器学习模型加载到主机上,并对数据进行预测。
芯片间互连 (ICI)
高速、低延迟的内部链路,可在 TPU Pod 内连接 TPU。
多切片
两个或更多个可通过 DCN 进行通信的 TPU 芯片切片。
节点
在多切片上下文中,节点是指单个 TPU 切片。MultiSlice 中的每个 TPU 切片都会获得一个节点 ID。
Pod
通过专用 ICI 网络接口连接的 TPU 芯片的集合。借助 Pod,您可以将处理负载分布到多个 TPU 中。
已加入队列的资源 (QR)
TPU 资源的表示法,用于将针对单切片或多切片 TPU 环境的请求加入队列和管理请求。
启动脚本
每次启动或重新启动虚拟机时运行的标准 Compute Engine 启动脚本。对于多切片,它在二维码创建请求中指定。如需详细了解 Cloud TPU 启动脚本,请参阅管理 TPU 资源
TPU 切片
由 TPU 芯片组成的 TPU Pod 的逻辑子部分。切片中的所有芯片均使用 ICI 网络相互通信。
TPU 虚拟机
运行 Linux 的虚拟机,能够访问底层 TPU。对于 v4 TPU,每个 TPU 虚拟机可以直接连接到四个芯片。有时,我们将 TPU 虚拟机称为“工作器”
张量
在机器学习模型中用于表示多维数据的数据结构。
张量处理单元 (TPU)
Google 内部开发的机器学习加速芯片。它们旨在为矩阵乘法等关键机器学习任务提供快速、高能效的计算。
Cloud TPU 容量的类型

您可以通过以下三种类型的容量创建 TPU(请参阅 TPU 定价方式中的“使用选项”):

  • 预留:定位预留的配额。如需使用预留配额,您必须与 Google 签订预留协议。请在创建资源时使用 --reserved 标志。
  • 抢占式:定位抢占式配额。您的资源可能会被抢占,以便为优先级较高的作业请求腾出空间。请在创建资源时使用 --best-effort 标志。
  • 按需:定位不需要预留且不会被抢占的按需配额。TPU 请求将加入 Cloud TPU 提供的按需配额队列,资源的可用性无法保证。默认处于选中状态,无需标记。

开始使用

如果您之前未使用过 TPU,请先安装 Google Cloud CLI,然后设置您的 Cloud TPU 环境。如需使用多切片,您必须将 TPU 资源作为已加入队列的资源进行管理。

如果您已经是 TPU v4 用户并且拥有预留,则可能需要将预留迁移到新的预留系统。如需了解详情,请与您的 Google Cloud 客户代表联系。

入门示例

本教程使用 MaxText GitHub 代码库中的代码。MaxText 是一个使用 Python 和 Jax 编写且经过良好测试的基本 LLM,性能出色、可任意伸缩、开源且经过测试。MaxText 的设计宗旨是在 Cloud TPU 上高效训练。

shardings.py 中的代码旨在帮助您开始使用不同的并行处理选项进行实验。例如,数据并行、完全分片数据并行 (FSDP) 和张量并行处理。代码可以从单个切片扩展到多切片环境。

ICI 并行处理

ICI 是指将多个 TPU 连接在一个切片中的高速互连。ICI 分片对应于切片中的分片。shardings.py 提供了三个 ICI 并行参数:

  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

您为这些参数指定的值决定了每种并行方法的分片数。

必须对这些输入进行约束,使 ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism 等于 Slice 中的条状标签数量。

下表显示了 v4-8 中提供的四个芯片的 ICI 并行处理示例用户输入:

ici_data_parallelism ici_fsdp_parallelism ici_tensor_parallelism
4 路 FSDP 1 4 1
四向 Tensor 并行处理 1 1 4
双向 FSDP + 双向 Tensor 并行处理 1 2 2

请注意,在大多数情况下,ici_data_parallelism 应保留为 1,因为 II 网络的速度足够快,几乎总是首选 FSDP 而不是数据并行。

此示例假定您熟悉如何在单个 TPU 切片上运行代码,例如使用 JAX 在 Cloud TPU 虚拟机上运行计算。以下示例展示了如何对单个 Slice 运行 shardings.py

  1. 设置环境:

    $ gcloud auth login
    $ gcloud config set project your-project-id
    $ gcloud config set compute/zone your-zone
    
  2. gcloud 创建 SSH 密钥。建议留空(运行以下命令后按 Enter 键两次)。如果系统提示 google_compute_engine 文件已存在,请替换现有版本。

    $ ssh-keygen -f ~/.ssh/google_compute_engine
    
  3. 使用以下命令预配 TPU:

    $ gcloud alpha compute tpus queued-resources \
    create your-qr-id \
    --accelerator-type your-accelerator-type \
    --runtime-version tpu-ubuntu2204-base \
    --node-id qr-id \
    [--reserved |--best-effort]
    

    命令标志说明

    your-qr-id
    用户定义的字符串,用于标识二维码请求。
    accelerator-type
    加速器类型指定要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    runtime-version
    [Cloud TPU 软件版本](/tpu/docs/supported-tpu-configurations#tpu_software_versions)。
    node-id
    为了响应二维码请求而创建的 TPU 资源的 ID。
    reserved
    创建切片时,请使用预留配额。
    best-effort
    尽力配额创建切片 [默认]。

    Google Cloud CLI 仅支持部分创建二维码选项,例如标记。如需了解详情,请参阅创建二维码

  4. 等待二维码处于 ACTIVE 状态,这意味着工作器节点处于 READY 状态。二维码配置开始后,可能需要一到五分钟才能完成,具体取决于二维码的大小。您可以使用以下命令检查二维码请求的状态:

    $ gcloud alpha compute tpus queued-resources \
      list --filter=your-qr-id
    
  5. 一个 v4-8 切片具有一个 TPU 虚拟机。使用 SSH 连接到 TPU 虚拟机:

    $ gcloud compute tpus tpu-vm ssh your-qr-id
    
  6. 将 MaxText(包含 shardings.py)克隆到 TPU 虚拟机。

  7. 在 MaxText 代码库目录中,运行设置脚本,在 TPU 切片上安装 JAX 和其他依赖项。设置脚本需要几分钟才能运行完成。

    $ bash setup.sh
    
  8. 运行以下命令,以在 TPU 切片上运行 shardings.py

    $ python3 pedagogical_examples/shardings.py \
      --ici_fsdp_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    

    您可以在日志中查看结果。您的 TPU 应该能够实现约每秒 260 TFLOP 或令人惊叹的 FLOP 利用率高达 90%以上!在本例中,我们选择了适合 TPU 的高带宽内存 (HBM) 的最大批量。

  9. 您可以随意探索基于 ICI 的其他分片策略,例如,您可以尝试以下组合:

    $ python3 pedagogical_examples/shardings.py \
      --ici_tensor_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    
  10. 完成后,删除二维码和 TPU 切片。您应该在设置切片的环境中运行这些清理步骤(首先运行 exit 以退出 SSH 会话)。删除操作将需要两到五分钟才能完成,并且可以使用可选的 --async 标志在后台运行。

    $ gcloud alpha compute tpus queued-resources
      delete your-qr-id --force (--async)
    

使用 DCN 并行式进行多切片分片

shardings.py 脚本使用三个参数来指定 DCN 并行处理,对应于每种并行数据的分片数:

  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism

必须限制这些参数的值,以使 dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism 等于切片的数量。

例如,对于两个 Slice,请使用 --dcn_data_parallelism = 2

dcn_data_parallelism dcn_fsdp_parallelism dcn_tensor_parallelism 切片数量
双向数据并行处理 2 1 1 2

dcn_tensor_parallelism 应始终设置为 1,因为 DCN 不适合此类分片。对于 v4 芯片上的典型 LLM 工作负载,dcn_fsdp_parallelism 也应设置为 1,因此 dcn_data_parallelism 也应设置为 Slice 数量,但此设置取决于应用。

随着切片数量的增加(假设切片大小和每个切片的批次保持不变),并行处理的数据量也会增加。

在 MultiSlice 环境中运行 shardings.py

您可以使用 multihost_runner.py 在多切片环境中运行 shardings.py,也可以通过对每个 TPU 虚拟机运行 shardings.py 来运行。在这里,我们使用 multihost_runner.py。以下步骤与 MaxText 代码库中的使用入门:多个切片的快速实验非常相似,只不过在这里,我们运行的是 shardings.py,而不是 train.py 中更复杂的 LLM。

multihost_runner.py 工具针对快速实验进行了优化,可反复重复使用相同的 TPU。由于 multihost_runner.py 脚本依赖于长期有效的 SSH 连接,因此我们不建议将它用于任何长时间运行的作业。如果要运行更长时间的作业(例如几小时或几天),我们建议您使用 multihost_job.py

在本教程中,我们使用“runner”一词表示运行 multihost_runner.py 脚本的机器。runner我们使用术语“工作器”来表示构成切片的 TPU 虚拟机。您可以在本地机器或切片所属项目的任何 Compute Engine 虚拟机上运行 multihost_runner.py。不支持在工作器上运行 multihost_runner.py

multihost_runner.py 会使用 SSH 自动连接到 TPU 工作器。

在此示例中,我们在两个 v4-16 切片(总共 4 个虚拟机和 16 个 TPU 芯片)上运行 shardings.py。您可以修改此示例,使其在更多 TPU 上运行。

设置您的环境

  1. 在运行程序机器上克隆 MaxText

  2. 转到代码库目录。

  3. gcloud 创建 SSH 密钥,我们建议留空密码(运行以下命令后按两次 Enter 键)。如果系统提示 google_compute_engine 文件已存在,请选择不保留现有版本。

      $ ssh-keygen -f ~/.ssh/google_compute_engine
      

  4. 添加一个环境变量,将 TPU 切片数量设置为 2

      $ export SLICE_COUNT=2
      

  5. 使用 queued-resources create 创建多切片环境。

    以下命令展示了如何创建 v4 Multislice TPU。如需使用 v5e,请指定 v5e accelerator-type(例如 v5litepod-16)和 v5e runtime-version (v2-alpha-tpuv5-lite)。

      $ gcloud alpha compute tpus queued-resources 
    create your-qr-id
    --accelerator-type=your-accelerator-type
    --runtime-version=tpu-vm-runtime-version
    --node-count=node-count
    --node-prefix=your-qr-id
    [--reserved|--best-effort]

    命令标志说明

    your-qr-id
    用户定义的字符串,用于标识二维码请求。
    accelerator-type
    加速器类型指定要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    runtime-version
    Cloud TPU 软件版本
    node-count
    要创建的 Slice 的数量。
    node-prefix
    用于为每个切片生成名称的前缀。每个切片的前缀后面会附加一个数字。例如,如果将 node-prefix 设置为 mySlice,则切片的名称为:mySlice-0mySlice-1,依此类推。
    reserved
    创建切片时,请使用预留配额。
    best-effort
    尽力配额创建切片 [默认]。

  6. 二维码配置开始后,最多可能需要五分钟才能完成,具体取决于二维码的大小。等待排队资源 (QR) 处于 ACTIVE 状态。您可以使用以下命令检查二维码请求的状态:

    $ gcloud alpha compute tpus queued-resources list \
    --filter=your-qr-id
    

    此命令应生成如下输出:

    NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
    ...
    que-res-id  us-central2-b  4           v4-16             ACTIVE
    ...
    

    如果二维码处于 WAITING_FOR_RESOURCESPROVISIONING 状态的时间超过 15 分钟,请与您的 Google Cloud 客户代表联系。

  7. 安装依赖项。

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="bash setup.sh"
    
  8. 使用 multihost_runner.py 对每个工作器运行 shardings.py

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="python3 pedagogical_examples/shardings.py \
      --dcn_data_parallelism $SLICE_COUNT \
      --ici_fsdp_parallelism 8 \
      --batch_size 131072 \
      --embedding_dimension 2048"
    

    您会在日志文件中看到每秒大约 230 TFLOP 的性能。

  9. 完成后,请清理 TPU 和二维码。删除操作需要两到五分钟才能完成,并且可以使用可选的 --async 标志在后台运行。

将工作负载扩缩为多切片

在多切片环境中运行模型之前,请进行以下代码更改:

改用 MultiSlice 时,只需进行以下代码更改即可。为了实现高性能,DCN 需要映射到并行、完全分片的数据并行或流水线并行轴。利用多切片进行分片以最大限度提高性能中更详细地讨论了性能注意事项和分片策略。

如需验证您的代码是否可以访问所有设备,您可以断言 len(jax.devices()) 等于多切片环境中的芯片数量。例如,如果您使用的是四个 v4-16 Slice,则每个 Slice 有 8 个条状标签 * 4 个,因此 len(jax.devices()) 应返回 32。

为多切片环境选择切片大小

如需线性加速,请添加与现有切片大小相同的新切片。例如,如果使用 v4-512 切片,Multislice 会添加第二个 v4-512 切片并将全局批次大小加倍,从而实现大约两倍的性能。如需了解详情,请参阅利用多切片进行分片以最大限度提高性能

在多个 Slice 上运行作业

您可以通过以下三种不同的方法在多切片环境中运行自定义工作负载:

  1. 使用实验运行程序脚本 multihost_runner.py
  2. 使用生产运行程序脚本 multihost_job.py
  3. 使用人工方法

实验运行程序脚本

multihost_runner.py 脚本会将代码分发到现有的 MultiSlice 环境,并在每个主机上运行您的命令,将日志复制回原语言,并跟踪每个命令的错误状态。MaxText 自述文件中介绍了 multihost_runner.py 脚本。

由于 multihost_runner.py 会维持持久性 SSH 连接,因此它仅适合规模较小、运行时间相对较短的实验。您可以根据工作负载和硬件配置调整 multihost_runner.py 教程中的步骤。

生产环境运行程序脚本

对于需要弹性以应对硬件故障和其他抢占的生产作业,最好直接集成 Create Queued Resource API。作为可正常工作的示例,我们提供了 multihost_job.py,它可使用适当的启动脚本触发 Created Queued Resource API 调用,以运行训练并在抢占后恢复资源。MaxText 自述文件中介绍了 multihost_job.py 脚本。

由于 multihost_job.py 必须为每次运行预配资源,因此它提供的迭代周期不如 multihost_runner.py 快。

人工方法

我们建议您使用或调整 multihost_runner.pymultihost_job.py 在多切片配置中运行自定义工作负载。但是,如果您希望直接使用二维码命令预配和管理您的环境,请参阅管理多切片环境

管理多切片环境

如需在不使用 MaxText 代码库中提供的工具的情况下手动配置和管理二维码,请阅读以下部分。

创建二维码

在预配容量之前设置以下环境变量:

  $ export your-qr-id=your-queued-resource-id
  $ export PROJECT=your-project-name
  $ export ZONE=us-central2-b
  $ export NETWORK_NAME=your-network-name
  $ export SUBNETWORK_NAME=your-subnetwork-name
  $ export RUNTIME_VERSION=tpu-ubuntu2204-base
  $ export ACCELERATOR_TYPE=v4-16
  $ export SLICE_COUNT=4
  $ export STARTUP_SCRIPT="#!/bin/bash\n ..."
  $ gcloud config set project project-name
  $ gcloud config set compute/zone zone
输入 说明
your-qr-id 用户指定的二维码 ID。
项目 Google Cloud 项目名称
区间 us-central2-b
NETWORK_NAME VPC 网络的名称。
SUBNETWORK_NAME VPC 网络中子网的名称
RUNTIME_VERSION tpu-ubuntu2204-base
ACCELERATOR_TYPE v4-16
EXAMPLE_TAG_1、EXAMPLE_TAG_2... 用于标识网络防火墙的有效来源或目标的标记
SLICE_COUNT 切片数量。上限为 256 个切片。
STARTUP_SCRIPT 如果将 启动脚本添加到创建请求中,则每当预配或重启 TPU 切片以及修复或重置 TPU 切片时,该脚本都可以运行。

使用 gcloud 创建二维码请求

$ gcloud alpha compute tpus queued-resources \
  create ${your-qr-id} \
  --project your-project-id \
  --zone your-zone \
  --node-count ${SLICE_COUNT} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --network ${NETWORK_NAME} \
  --subnetwork ${SUBNETWORK_NAME} \
  --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}'
  [--reserved|--best-effort]
  

命令标志说明

your-qr-id
用户定义的字符串,用于标识二维码请求。
project
用户定义的字符串,用于标识二维码请求。
zone
要在其中创建二维码的 Google Cloud 区域。
node-count
要创建的 Slice 的数量。
accelerator-type
加速器类型指定要创建的 Cloud TPU 的版本和大小。如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
runtime-version
Cloud TPU 软件版本
network
要连接到 TPU 资源的 VPC 网络的名称。
subnetwork
要连接到 TPU 资源的 VPC 子网的名称。
reserved
创建切片时,请使用预留配额。
best-effort
尽力配额创建切片 [默认]。

在选择 --reserved--best_effort 或默认的按需配额之前,请确保您拥有相应的配额。如需了解配额类型,请参阅配额政策

使用 curl 创建二维码请求

创建一个名为 queued-resource-req.json 的文件,并将以下 JSON 复制到其中。

{
  "guaranteed": { "reserved": true },
  "tpu": {
    "node_spec": [
    {
      "parent": "projects/your-project-number/locations/your-zone",
        "node": {
          "accelerator_type": "accelerator-type",
          "runtime_version": "tpu-vm-runtime-version",
          "network_config": {
            "network": "your-network-name",
            "subnetwork": "your-subnetwork-name",
            "enable_external_ips": true
          },
          "tags" : ["example-tag-1"]
          "metadata": {
            "startup-script": "your-startup-script"
          }
      },
      "multi_node_params": {
        "node_count": slice-count,
        "node_id_prefix": "your-queued-resource-id"
      }
    }
    ]
  }
}
  • your-project-number - 您的 Google Cloud 项目编号
  • your-zone - 您要在其中创建二维码的区域
  • accelerator-type - 单个切片的版本和大小
  • tpu-vm-runtime-version - TPU 虚拟机运行时版本
  • your-network-name - 可选,将附加二维码的网络
  • your-subnetwork-name - 可选,附加二维码的子网
  • example-tag-1 - 可选,任意标记字符串
  • your-startup-script - 将在分配二维码时运行的启动脚本
  • slice-count - 多切片环境中的 TPU 切片的数量
  • your-qr-id - 用户为二维码提供的 ID

如需了解详情,请参阅 REST Queued Resource API 文档,了解所有可用选项。

如需使用抢占式容量,请执行以下命令:

"guaranteed": { "reserved": true } - "best_effort": {}

或者移除该行以使用默认的按需容量。

提交包含 JSON 载荷的二维码创建请求:

  $ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" -d @queuedresourcereq.json https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
  • your-project-id - 您的 Google Cloud 项目 ID
  • your-zone - 您要在其中创建二维码的区域
  • your-qr-id - 用户为二维码提供的 ID

响应应如下所示:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
  },
  "done": false
}

使用 name 属性的字符串值末尾的 GUID 值来获取有关二维码请求的信息。

检索二维码的状态

如需获取二维码请求的状态,请使用以下命令:

  $ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
  • your-project-id - 您的 Google Cloud 项目 ID
  • your-zone - 要创建二维码的区域
  • your-qr-guid - QR 创建请求输出中 name 后面的 GUID。

此命令的响应包含操作的状态:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qa-guid>,
  "metadata": {...},
  "done": true,
  "response": {
    "@type": "type.googleapis.com/google.cloud.tpu.v2alpha1.QueuedResource",
    ...
    "state": {
      "state": "WAITING_FOR_RESOURCES"
    }
  }
}

如果二维码成功创建 ("done = true"),则 response 字段中的状态将为 WAITING_FOR_RESOURCESFAILED。如果二维码处于 WAITING_FOR_RESOURCES 状态,则表示该二维码已加入队列,并将在有足够的资源后开始配置。如果二维码处于 FAILED 状态,则输出中会显示失败原因。如需详细了解其他可能的状态,请参阅已加入队列的资源用户指南

操作完成后,请使用描述二维码来监控二维码的各个阶段。

在极少数情况下,您可能会发现二维码处于 FAILED 状态,而某些 Slice 是 ACTIVE 状态。如果发生这种情况,请删除已创建的资源,然后过几分钟再试,或联系 Cloud TPU 团队解决此问题。

SSH 并安装依赖项

在 TPU Pod 切片上运行 JAX 代码介绍了如何在单个切片中使用 SSH 连接到 TPU 虚拟机。如需通过 SSH 连接到多切片环境中的所有 TPU 虚拟机并安装依赖项,请使用以下 gcloud 命令:

  $ gcloud compute tpus queued-resources ssh ${your-qr-id} \
    --zone your-zone \
    --node=all \
    --worker=all \
    --command="command-to-run"
    --batch-size=4

gcloud 命令使用 SSH 将指定命令发送到 QR 中的所有工作器和节点。该命令将分为四组,然后同时发送。当前批次完成执行后,系统会发送下一批命令。如果其中某个命令失败,处理会停止,并且不再发送批次。如需了解详情,请参阅队列资源 API 参考文档。如果您使用的切片数量超出本地计算机的线程处理限制(也称为批处理限制),您将遇到死锁。例如,假设本地机器上的批处理限制为 64。如果您尝试对超过 64 个切片(例如 100 个)运行训练脚本,则 SSH 命令会将这些切片拆分为多个批次。它将对第一批 64 个切片运行训练脚本,并等待脚本完成后再对其余 36 个切片运行该脚本。但是,在剩余的 36 个切片开始运行脚本之前,第一批 64 个切片无法完成,从而导致死锁。

为避免这种情况,您可以在每个虚拟机上在后台运行训练脚本,只需在您使用 --command 标志指定的脚本命令中附加和号 (&) 即可。执行此操作后,在第一批切片上启动训练脚本后,控制权将立即返回到 SSH 命令上。然后,SSH 命令可以开始对剩余 36 个切片运行训练脚本。在后台运行命令时,您需要适当地传输 stdoutstderr 流。如需在同一二维码内提高并行性,您可以使用 --node 参数选择特定的 Slice。

广告网络设置

通过执行以下步骤,确保 TPU 切片可以相互通信。在每个切片上安装 JAX。如需了解详情,请参阅在 TPU Pod 切片上运行 JAX 代码。断言 len(jax.devices()) 等于 MultiSlice 环境中的芯片数量。为此,请对每个切片运行以下命令:

  $ python3 -c 'import jax; print(jax.devices())'

如果您在 v4-16 的四个 Slice 上运行此代码,则每个 Slice 有 8 个芯片和 4 个 Slice,jax.devices() 应返回总共 32 个芯片(设备)。

列出二维码

您可以使用 queued-resources list 命令查看二维码的状态:

$ gcloud alpha compute tpus queued-resources list

NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
...
que-res-id  us-central2-b  4           v4-16             ACTIVE
...

描述二维码

如需查看二维码的详细配置和状态,请使用描述二维码 API。您可以使用 gcloudcurl 调用此 API。

使用 gcloud

$ gcloud alpha compute tpus queued-resources describe ${your-qr-id}
...state:
 state: ACTIVE
...

使用 curl

$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources/${your-qr-id}
{
  "name": your-queued-res,
  "tpu": {
    "nodeSpec": [
      {
        ... // node 1
      },
      {
        ... // node 2
      },
      ...
    ]
  },
  ...
  "state": "ACTIVE"
}

state 表示二维码的状态。如需详细了解二维码的可能状态,请参阅已加入队列的资源

在预配的环境中启动作业

您可以手动运行工作负载,方法是通过 SSH 连接到每个切片中的所有主机,并在所有主机上运行以下命令。

$ gcloud compute tpus tpu-vm ssh your-qr-id \
  --zone=your-zone \
  --worker=all \
  --node=all \
  --command="command-to-run"

重置二维码

ResetQueuedResource API 可用于重置 ACTIVE 二维码中的所有虚拟机。重置虚拟机会强制清空机器内存并将虚拟机重置为初始状态。本地存储的所有数据都将保持不变,并且将在重置后调用启动脚本。如果要重启所有 TPU,ResetQueuedResource API 会很有用。例如,当训练卡住,重置所有虚拟机比调试更容易。

所有虚拟机的重置是并行执行,ResetQueuedResource 操作需要一到两分钟才能完成。如需调用该 API,请使用以下命令:

$ gcloud alpha compute tpus queued-resources reset your-qr-id

正在删除二维码

如需在训练会话结束时释放资源,请使用 --force 标志删除已加入队列的资源。删除操作将需要两到五分钟才能完成,并且可以使用可选的 --async 标志在后台运行。

$ gcloud alpha compute tpus queued-resources \
delete your-qr-id --force (--async)

自动故障恢复

如果发生中断,Multislice 会针对受影响的切片提供无干预的修复,并在之后重置所有切片。受影响的切片会被替换为新的切片,而其余健康状况良好的切片会被重置。如果没有容量可用于分配替换切片,训练会停止。

如需在中断后自动恢复训练,您必须指定一个启动脚本,用于检查并加载上次保存的检查点。每次重新分配切片或重置虚拟机时,启动脚本都会自动运行。您可以在发送到创建二维码请求 API 的 JSON 载荷中指定启动脚本。

通过以下启动脚本(在创建二维码中),您可以从失败中自动恢复,并在 MaxText 训练期间从存储在 Cloud Storage 存储桶中的检查点继续训练:

{
 "tpu": {
   "node_spec": [
     {
      ...
         "metadata": {
               "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
         }
     ...
     }
   ]
 }
}

请先克隆 MaxText 代码库,然后再尝试此操作。

性能分析和调试

在单切片和多切片环境中,分析是一样的。如需了解详情,请参阅剖析 JAX 程序的性能

优化培训

使用 MultiSlice 进行分片以实现最佳性能

在多切片环境中实现最高性能需要考虑如何跨多个切片进行分片。通常有三种选择(数据并行处理、完全分片数据并行处理和流水线并行处理)。我们不建议跨模型维度对激活进行分片(有时称为张量并行),因为这需要太多的切片间带宽。对于所有这些策略,您可以在过去成功运作的切片中保留相同的分片策略。

我们建议从纯数据并行开始。使用完全分片数据并行处理有助于释放内存用量。缺点是切片之间的通信会使用 DCN 网络,并且会降低工作负载的速度。根据批次大小,仅在必要时使用流水线并行处理(分析如下)。

何时使用数据并行处理

如果您的工作负载运行良好,但您希望通过跨多个切片进行伸缩来提高其性能,则纯数据并行性会非常好。

如需跨多个 Slice 实现强有力的伸缩,在 DCN 上执行 all-Reduce 所需的时间必须少于执行反向传递所需的时间。DCN 用于切片之间的通信,是工作负载吞吐量的限制因素。

每个 v4 TPU 芯片的性能峰值为每秒 275 * 1012 FLOPS。

每个 TPU 主机有四个芯片,每个主机的最大网络带宽为 50 Gbps。

这意味着算术强度为 4 * 275 * 1012 FLOPS / 50 Gbps = 22000 FLOPS / 位。

您的模型将为每步的每个参数使用 32 到 64 位 DCN 带宽。如果使用两个切片,您的模型将使用 32 位 DCN 带宽。如果您使用两个以上的 Slice,编译器就会执行完整的 shuffle all-reduce 操作,并且每个步骤的每个参数最多会使用 64 位 DCN 带宽。每个参数所需的 FLOPS 量因模型而异。具体而言,对于基于 Transformer 的语言模型,前向传播和后向传播所需的 FLOPS 数量约为 6 * B * P,其中:

  • B 是词元中的批次大小
  • P 是参数的数量

每个参数的 FLOPS 数量为 6 * B,反向传递期间每个参数的 FLOPS 数量为 4 * B

为了确保跨多个切片进行强有力的伸缩,请确保操作强度超过了 TPU 硬件的算术强度。如需计算操作强度,请将后向传播期间每个参数的 FLOPS 数除以每个步每个参数的网络带宽(以位为单位):Operational Intensity = FLOPSbackwards_pass / DCN bandwidth

因此,对于基于 Transformer 的语言模型,如果您使用两个切片:Operational intensity = 4 * B / 32

如果您使用的切片超过两个:Operational intensity = 4 * B/64

这表明,对于基于 Transformer 的语言模型,批次大小下限介于 176k 到 352k 之间。由于 DCN 网络可能会短暂地丢弃数据包,因此最好保留明显的误差空间,只有在每个 Pod 的批次大小至少为 35 万(两个 Pod)到 70 万(多个 Pod)时,才并行部署数据。

对于其他模型架构,您需要估算每个 Slice 的向后传递的运行时间(通过使用性能分析器计时或计算 FLOPS)。然后,您可以将该时间与预期运行时间进行比较,以通过 DCN 减少所有数据,并大致估算数据并行性是否适合您。

何时使用完全分片数据并行 (FSDP)

完全分片数据并行 (FSDP) 可兼顾数据并行性(跨节点将数据分片)与跨节点的权重分片相结合的方式。对于前向传播和反向传递中的每个运算,系统会全部收集权重,以使每个切片都具有所需的权重。不是使用 all-reduce 同步梯度,而是在生成渐变时减小渐变效果。这样一来,每个切片仅获得其所负责的权重的梯度。

与数据并行处理类似,FSDP 需要根据切片数量线性伸缩全局批次大小。随着切片数量的增加,FSDP 会降低内存压力。这是因为每个切片的权重和优化器状态数量会减少,但代价是网络流量会增加,且因收集延迟而被屏蔽的可能性更大。

在实践中,如果要增加每个切片的批次,存储更多激活以最大限度地减少反向传递期间的重新具体化,或增加神经网络中的参数数量,那么跨切片的 FSDP 是最佳选择。

FSDP 中的全收集和全减少操作的工作方式与 DP 中的类似,因此您可以按照上一部分中所述的方式确定 FSDP 工作负载是否受 DCN 性能限制。

何时使用流水线并行处理

当使用需要全局批次大小大于首选批次大小的其他并行策略时,流水线并行性变得相关。流水线并行处理允许组成流水线的切片“共享”批次。但是,流水线并行处理有两个重大缺点:

  1. 它会产生“流水线气泡”,其中芯片因等待数据而处于空闲状态。
  2. 它需要进行微批处理,这会减少有效批次大小和算术强度,并最终对 FLOP 利用率进行建模。

仅当其他并行策略需要的全局批次大小过大时,才应使用流水线并行处理。在尝试流水线并行处理之前,有必要根据经验了解每个样本的收敛是否以实现高性能 FSDP 所需的批次大小减慢。FSDP 往往会获得更高的模型 FLOP 利用率,但如果每个样本的收敛会随着批次大小的增加而减慢,则流水线并行处理可能仍然是更好的选择。大多数工作负载都可以容忍足够大的批量,而无法从流水线并行处理中获益,但您的工作负载可能会有所不同。

如果需要并行处理流水线,我们建议将其与数据并行处理或 FSDP 结合使用。这样,您就可以最大限度地减小流水线深度,同时增加每个流水线的批次大小,直到 DCN 延迟时间不再是吞吐量的一个因素。具体而言,如果您有 N 个切片,请考虑使用深度为 2 和 N/2 个数据并行副本的流水线,然后采用深度为 4 和 N/4 的数据并行副本的流水线,以此类推,直到每个流水线的批次变得足够大,可以在向后传递时将 DCN 集合隐藏在算术后面。这样可以最大程度减少流水线并行性造成的速度缓慢,同时使您能够在超出全局批次大小限制的情况下进行扩容。

多切片最佳实践

数据加载

在训练期间,我们会反复从数据集中加载批次以馈送到模型中。拥有一个高效的异步数据加载器,将批量处理跨主机分片,对于避免 TPU 工作负载匮乏非常重要。MaxText 中的当前数据加载器会为每个主机加载相等的示例子集。此解决方案足以处理文本,但需要在模型内重新分片。此外,MaxText 尚未提供确定性快照(允许数据迭代器在抢占之前和之后加载相同的数据)。

检查点

Orbax 检查点库提供了用于将 JAX PyTree 设置到本地存储空间或 Google Cloud 存储空间的检查点基元。我们在 checkpointing.py 中的 MaxText 中提供了具有同步检查点的参考集成。

受支持的配置

形状

所有切片都必须具有相同的形状(例如,相同的 AcceleratorType)。不支持异构切片形状。

编排

GKE 支持编排。如需了解详情,请参阅 GKE 中的 TPU

框架

Multislice 仅支持 JAX 和 PyTorch 工作负载。

并行数量

我们建议用户使用数据并行测试多切片。如需详细了解如何使用 Multislice 实现流水线并行处理,请与您的 Google Cloud 客户代表联系。

支持与反馈

我们欢迎各种反馈!如需分享反馈或请求支持,请通过 Cloud TPU 支持或反馈表单与我们联系。