Cloud TPU 多切片概览

Cloud TPU Multislice 是一种全栈性能伸缩技术 可让训练作业在单个 Pod 内或单个 Pod 中 并行处理多个 Pod 中的切片。对于 TPU v4 芯片, 意味着训练作业可以在一次运行中使用超过 4096 个芯片。训练用 对于需要少于 4096 个条状标签的作业,单个切片可以提供 性能然而,多个较小的切片更容易获得, 当以更小尺寸使用多切片时,可以缩短启动时间 。

多个切片可线性扩缩性能

以多切片配置部署时,每个切片中的 TPU 芯片 通过芯片间互连 (ICI) 进行通信。TPU 芯片在 切片通过将数据传输到 CPU(主机)来进行通信 通过数据中心网络 (DCN) 传输数据。

多切片数据流

开发者无需编写代码即可实现切片间 DCN 通信。 XLA 编译器会为您生成该代码, 以实现最佳性能。

概念

加速器类型
包含一个多切片的每个 TPU 切片的形状。每个 多切片请求中的切片属于同一加速器类型。加速器 类型由 TPU 类型(v4 或 v5e)组成,后跟 TensorCores。例如,v4-128 指定具有 128 个 TensorCore 的 TPU v4。
自动修复
当切片遇到维护事件、抢占或硬件故障时, Cloud TPU 将创建一个新的切片。在极少数情况下 资源不足,无法创建新切片,创建将无法完成 直到有可用的硬件在新切片创建后,所有其他 多切片环境中的切片将重新开始, 使用正确配置的启动脚本后,训练脚本 可以自动重新启动,而无需用户干预、加载和恢复 从最新的检查点开始。
数据集
模型用于训练或推理的数据。
数据中心网络 (DCN)
延迟时间较长、吞吐量较低(与 ICI 相比) 连接多切片配置中的 TPU 切片。
帮派调度
同时对所有 TPU 切片进行预配时,可以保证 或者所有切片都未成功预配。
主机
主机是运行虚拟机的物理计算机。一个主机最多可以运行四个虚拟机 。每个虚拟机都有一个专用 TPU。
推断
将预训练的机器学习模型加载到主机上并进行预测, 数据。
芯片间互连 (ICI)
用于连接 TPU Pod 内多个 TPU 的高速、低延迟内部链路。
多切片
两个或更多可通过 DCN 通信的 TPU 芯片切片。
节点
在多切片上下文中,节点是指单个 TPU 切片。每个 系统会为多切片中的 TPU 切片指定节点 ID。
Pod
通过专用 ICI 网络接口连接的一系列 TPU 芯片。答 利用 Pod,您可以在多个 TPU 之间分配处理负载。
已加入队列的资源 (QR)
TPU 资源的表示形式,用于排队和管理对 单切片或多切片 TPU 环境。
启动脚本
标准的 Compute Engine 启动脚本 每次启动或重新启动虚拟机时都会运行这个命令。对于多切片: 是在 QR 创建请求中指定的。更多信息 请参阅管理 Cloud TPU 资源
TPU 切片
由 TPU 芯片组成的 TPU Pod 的逻辑子部分。同一类中的所有条状标签 Slice 使用 ICI 网络相互通信。
TPU 虚拟机
一台运行 Linux 并且能够访问底层 TPU 的虚拟机。对于 v4 TPU,每个 TPU 虚拟机都可以直接访问四个芯片。有时,我们称之为 TPU, 虚拟机作为工作器
Tensor
一种数据结构,用于表示机器中的多维数据 机器学习模型。
张量处理单元 (TPU)
Google 内部开发的机器学习加速芯片。它们旨在 为关键机器学习任务(例如 矩阵乘法。
Cloud TPU 容量类型

可以通过不同类型的容量创建 TPU(请参阅 TPU 定价方式):

  • 预留:目标预留配额。要使用预留的配额,您必须拥有一个 。在创建--reserved 资源
  • Spot:使用 Spot 虚拟机定位抢占式配额。您的 从而为更高优先级的请求腾出空间 优先级作业。创建资源时,请使用 --spot 标志。
  • 按需:以按需配额为目标,不需要预留 并且不会被抢占TPU 请求将加入按需队列 配额队列,资源可用性 。默认处于选中状态,无需任何标志。

开始使用

如果您之前没有使用过 TPU,请先安装 Google Cloud CLI, 以及如何设置您的 Cloud TPU 环境。要使用 多切片,您的 TPU 资源必须作为已加入队列的资源进行管理。

如果您已经是 TPU v4 用户并且有预留,则可能需要迁移 新的预订系统如需更多信息 请与您的 Google Cloud 客户代表联系。

入门示例

本教程使用 MaxText GitHub 代码库中的代码。 MaxText 性能出色、可任意伸缩、开源且经过充分测试 使用 Python 和 Jax 编写的基本 LLM。MaxText 旨在 Cloud TPU。

shardings.py 中的代码 旨在帮助您开始尝试不同的并行性 选项。例如,数据并行处理、完全分片数据并行处理 (FSDP), 和张量并行处理。代码从单个切片扩展到多切片 环境

ICI 并行处理

ICI 是指在单个 Pod 中连接多个 TPU 的高速互连。 。ICI 分片对应于切片内的分片。shardings.py 提供三个 ICI 并行参数:

  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

您为这些参数指定的值决定了 并行处理方法

必须对这些输入加以限制, ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism等于 表示切片中条状标签的数量。

下表显示了四个示例,用于实现 ICI 并行处理 v4-8 中提供的条状标签:

ici_data_parallelism ici_fsdp_parallelism ici_tensor_parallelism
4 声道 FSDP 1 4 1
四向张量并行处理 1 1 4
双向 FSDP + 双向 Tensor 并行处理 1 2 2

请注意,在大多数情况下,ici_data_parallelism 应保留为 1,因为 ICI 网络的速度足够快,几乎总是首选 FSDP 而不是数据并行处理。

此示例假定您熟悉在单个 TPU 切片上运行代码 例如使用 JAX 在 Cloud TPU 虚拟机上运行计算。 以下示例展示了如何在单个 Slice 上运行 shardings.py

  1. 设置环境:

    $ gcloud auth login
    $ gcloud config set project your-project-id
    $ gcloud config set compute/zone your-zone
    
  2. gcloud 创建 SSH 密钥。我们建议将密码留空(按 在运行以下命令后输入两次)。如果系统提示您 google_compute_engine 文件已存在,请替换现有版本。

    $ ssh-keygen -f ~/.ssh/google_compute_engine
    
  3. 使用以下命令预配 TPU:

    $ gcloud alpha compute tpus queued-resources \
    create your-qr-id \
    --accelerator-type your-accelerator-type \
    --runtime-version tpu-ubuntu2204-base \
    --node-id qr-id \
    [--reserved |--spot]
    

    命令标志说明

    your-qr-id
    一个用户定义的字符串,用于标识二维码请求。
    accelerator-type
    加速器类型指定要创建的 Cloud TPU 的版本和大小。 如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    runtime-version
    [Cloud TPU 软件版本](/tpu/docs/supported-tpu-configurations#tpu_software_versions)。
    node-id
    将在响应 二维码请求。
    reserved
    创建切片时使用预留的配额。
    best-effort
    创建切片时尽量使用配额 [默认]。

    Google Cloud CLI 不支持某些二维码创建选项,例如标记。 如需了解详情,请参阅创建二维码

  4. 等待二维码进入 ACTIVE 状态,这意味着工作器节点处于 处于 READY 状态。二维码配置开始后,可能需要一到五个时间 所需时间长短取决于二维码的大小。您可以 检查相应的二维码:

    $ gcloud compute tpus queued-resources \
      list --filter=your-qr-id
    
  5. 一个 v4-8 切片具有一个 TPU 虚拟机。使用 SSH 连接到 TPU 虚拟机:

    $ gcloud compute tpus tpu-vm ssh your-qr-id
    
  6. 将 MaxText(包含 shardings.py)克隆到 TPU 虚拟机。

  7. 在 MaxText 代码库目录中,运行安装脚本进行安装 JAX 和 TPU 切片上的其他依赖项。设置脚本 分钟。

    $ bash setup.sh
    
  8. 运行以下命令,在 TPU 切片上运行 shardings.py

    $ python3 pedagogical_examples/shardings.py \
      --ici_fsdp_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    

    您可以在日志中查看结果。您的 TPU 应该会达到大约 260 TFLOP 或可观的 90%以上 FLOP 利用率!在本示例中,我们 选择大约适合 TPU 高容量的 带宽内存 (HBM)。

  9. 欢迎随时探索其他分片策略 例如,您可以尝试以下组合:

    $ python3 pedagogical_examples/shardings.py \
      --ici_tensor_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    
  10. 完成后删除二维码和 TPU 切片。您应该执行这些清理 在您设置 Slice 的环境中执行的步骤(首先运行 exit 到 退出 SSH 会话)。删除过程需要两到五分钟才能完成 可使用可选的 --async 标志在后台运行。

    $ gcloud compute tpus queued-resources
      delete your-qr-id --force (--async)
    

使用 DCN 并行处理的多切片分片

shardings.py 脚本接受三个参数,用于指定 DCN 并行处理, 对应于每种数据并行处理的分片数:

  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism

必须对这些参数的值进行约束,以便 dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism等于 切片数量。

以两个 Slice 为例,请使用 --dcn_data_parallelism = 2

dcn_data_parallelism dcn_fsdp_parallelism dcn_tensor_parallelism 切片数量
双向数据并行处理 2 1 1 2

dcn_tensor_parallelism 应始终设置为 1,因为 DCN 质量不佳 适合这种分片。对于 v4 芯片上的典型 LLM 工作负载, dcn_fsdp_parallelism 也应设置为 1,因此 dcn_data_parallelism 应设置为 Slice 的数量,但实际上是 取决于应用。

随着切片数量的增加(假设您保持切片大小和批次大小) 每个切片常量),则增加数据并行量。

在多切片环境中运行 shardings.py

您可以使用以下命令在多切片环境中运行 shardings.pymultihost_runner.py 或在每个 TPU 虚拟机上运行 shardings.py。在这里,我们使用 multihost_runner.py.以下步骤与 使用入门:针对多个切片的快速实验 来自 MaxText 代码库,不过在这里我们运行 shardings.py,而不是 train.py 中更复杂的 LLM。

multihost_runner.py 工具针对快速实验进行了优化,可反复运行 相同的 TPU由于 multihost_runner.py 脚本依赖于 长期有效的 SSH 连接,我们不建议将其用于任何长时间运行的作业。 如果您想运行时间更长(例如数小时或数天)的作业,我们建议您 请使用 multihost_job.py

在本教程中,我们使用术语“runner”来指示运行 运行 multihost_runner.py 脚本。我们用“工作器”一词表示 构成切片的 TPU 虚拟机。你可以在本地机器上运行 multihost_runner.py, 或与您的切片位于同一项目中的任何 Compute Engine 虚拟机。正在运行 不支持在工作器上使用 multihost_runner.py

multihost_runner.py 使用 SSH 自动连接到 TPU 工作器。

在此示例中,我们在两个 v4-16 切片(总共四个切片)上运行 shardings.py, 虚拟机和 16 个 TPU 芯片。您可以修改示例,以便在更多 TPU 上运行。

设置环境

  1. 在运行程序上克隆 MaxText 虚拟机。

  2. 转到代码库目录。

  3. gcloud 创建 SSH 密钥,我们建议将密码留空(按 在运行以下命令后输入两次)。如果系统提示您 google_compute_engine 文件已存在,选择不保留您的 现有版本。

      $ ssh-keygen -f ~/.ssh/google_compute_engine
      

  4. 添加一个环境变量,将 TPU 切片数量设置为 2

      $ export SLICE_COUNT=2
      

  5. 使用 queued-resources create 创建多切片环境。

    以下命令展示了如何创建 v4 Multislice TPU。要使用 v5e,请指定 v5e accelerator-type(例如 v5litepod-16)和 v5e runtime-version (v2-alpha-tpuv5-lite)。

      $ gcloud alpha compute tpus queued-resources 
    create your-qr-id
    --accelerator-type=your-accelerator-type
    --runtime-version=tpu-vm-runtime-version
    --node-count=node-count
    --node-prefix=your-qr-id
    [--reserved|--spot]

    命令标志说明

    your-qr-id
    一个用户定义的字符串,用于标识二维码请求。
    accelerator-type
    加速器类型指定要创建的 Cloud TPU 的版本和大小。 如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
    runtime-version
    Cloud TPU 软件版本
    node-count
    要创建的 Slice 数量。
    node-prefix
    用于为每个切片生成名称的前缀。已附加数字 添加到每个 Slice 的前缀中。例如,如果您将 node-prefix 更改为 mySlice,这些 Slice 将命名为 mySlice-0mySlice-1 等。
    reserved
    创建切片时使用预留的配额。
    best-effort
    创建切片时尽量使用配额 [默认]。

  6. 二维码配置开始时,最多可能需要五分钟才能完成 具体取决于 QR 码的大小。等待排队的资源 (QR) 进入 ACTIVE 状态。您可以使用 以下命令:

    $ gcloud compute tpus queued-resources list \
    --filter=your-qr-id
    

    此命令应生成如下所示的输出:

    NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
    ...
    que-res-id  us-central2-b  4           v4-16             ACTIVE
    ...
    

    如果二维 Google Cloud 状态为 WAITING_FOR_RESOURCESPROVISIONING 状态超过 15 分钟。

  7. 安装依赖项。

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="bash setup.sh"
    
  8. 使用 multihost_runner.py 在每个工作器上运行 shardings.py

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="python3 pedagogical_examples/shardings.py \
      --dcn_data_parallelism $SLICE_COUNT \
      --ici_fsdp_parallelism 8 \
      --batch_size 131072 \
      --embedding_dimension 2048"
    

    您会在日志中看到每秒大约 230 TFLOP 的性能 文件。

  9. 完成后请清理 TPU 和二维码。删除过程需要两到五分钟 分钟完成,并且可以在后台使用可选的 --async 标志。

将工作负载扩缩为多切片

在多切片环境中运行模型之前, 以下代码更改:

这些应该是移至多切片后唯一需要进行的代码更改。 为了实现高性能,需要将 DCN 映射到并行数据, 并行轴或流水线并行轴。性能注意事项和 有关分片策略的详细介绍,请参见 使用多切片进行分片以获得最佳性能

要验证您的代码能否访问所有设备,您可以断言 len(jax.devices())等于多切片中的条状标签数量 环境例如,如果您使用四个 v4-16 切片,则 每个切片八个芯片 * 4 个切片,因此 len(jax.devices()) 应返回 32。

为多切片环境选择切片大小

要获得线性速度,请添加与现有切片相同大小的新切片 。例如,如果您使用 v4-512 切片,则多切片将 通过再添加一个 v4-512 切片来实现大约两倍的性能 并将全局批次大小加倍。如需了解详情,请参阅 使用多切片进行分片以获得最佳性能

在多个切片上运行作业

您可以通过三种不同的方法 多切片环境:

  1. 使用实验运行程序脚本 multihost_runner.py
  2. 使用生产运行程序脚本 multihost_job.py
  3. 使用手动方法

实验运行程序脚本

multihost_runner.py 脚本将代码分发到现有的多切片环境,并运行 在每个主机上运行您的命令,将日志复制回来,并跟踪每个命令的错误 状态。multihost_runner.py 脚本记录在 MaxText 自述文件

由于 multihost_runner.py 会维护永久性 SSH 连接,因此 适合规模适中、运行时间相对较短的实验。您可以 调整 multihost_runner.py 教程中的步骤 工作负载和硬件配置

生产环境运行程序脚本

适用于需要灵活应对硬件故障和其他故障的生产作业 最好直接与“创建排队的资源”集成 API.作为使用中的示例,我们提供了 multihost_job.py, 该调用会在适当的启动时触发“Created Queued Resource API”调用 脚本来运行训练并在抢占时恢复。multihost_job.py 脚本均记录在 MaxText 自述文件

由于 multihost_job.py 必须为每次运行预配资源,因此它不会 提供与 multihost_runner.py 一样快的迭代周期。

手动方法

我们建议您使用或调整 multihost_runner.pymultihost_job.py,以运行您的自定义工作负载 您的多切片配置。但是,如果您希望 直接使用 QR 命令来管理您的环境,请参阅 管理多切片环境

管理多切片环境

在不使用工具的情况下手动配置和管理二维码 MaxText 代码库中提供的应用,请阅读 后续部分。

创建二维码

在预配容量之前,请先设置以下环境变量:

  $ export your-qr-id=your-queued-resource-id
  $ export PROJECT=your-project-name
  $ export ZONE=us-central2-b
  $ export NETWORK_NAME=your-network-name
  $ export SUBNETWORK_NAME=your-subnetwork-name
  $ export RUNTIME_VERSION=tpu-ubuntu2204-base
  $ export ACCELERATOR_TYPE=v4-16
  $ export SLICE_COUNT=4
  $ export STARTUP_SCRIPT="#!/bin/bash\n ..."
  $ gcloud config set project project-name
  $ gcloud config set compute/zone zone
输入 说明
your-qr-id 用户指定的二维码 ID。
项目 Google Cloud 项目名称
可用区 us-central2-b
NETWORK_NAME VPC 网络的名称。
SUBNETWORK_NAME VPC 网络中子网的名称
RUNTIME_VERSION tpu-ubuntu2204-base
ACCELERATOR_TYPE v4-16
EXAMPLE_TAG_1、EXAMPLE_TAG_2... 用于标识网络防火墙的有效来源或目标的标记
SLICE_COUNT 切片数。上限为 256 个 Slice。
STARTUP_SCRIPT 如果被添加到创建请求中, 启动脚本可以在每次预配或重启 TPU 切片时运行 以及 TPU 切片是否已修复或重置。

使用 gcloud 创建二维码请求

$ gcloud alpha compute tpus queued-resources \
  create ${your-qr-id} \
  --project your-project-id \
  --zone your-zone \
  --node-count ${SLICE_COUNT} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --network ${NETWORK_NAME} \
  --subnetwork ${SUBNETWORK_NAME} \
  --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}'
  [--reserved|--spot]
  

命令标志说明

your-qr-id
一个用户定义的字符串,用于标识二维码请求。
project
一个用户定义的字符串,用于标识二维码请求。
zone
要在其中创建二维码的 Google Cloud 区域。
node-count
要创建的 Slice 数量。
accelerator-type
加速器类型指定要创建的 Cloud TPU 的版本和大小。 如需详细了解每个 TPU 版本支持的加速器类型,请参阅 TPU 版本
runtime-version
Cloud TPU 软件版本
network
要挂接 TPU 资源的 VPC 网络的名称。
subnetwork
要挂接 TPU 资源的 VPC 子网的名称。
reserved
创建切片时使用预留的配额。
spot
创建切片时使用 Spot 虚拟机配额。

在选择 --reserved 之前,请确保您拥有相应的配额。 --spot 或默认的按需配额。如需了解配额类型 请参阅配额政策

使用 curl 创建二维码请求

创建名为 queued-resource-req.json 的文件,并将以下 JSON 复制到其中。

{
  "guaranteed": { "reserved": true },
  "tpu": {
    "node_spec": [
    {
      "parent": "projects/your-project-number/locations/your-zone",
        "node": {
          "accelerator_type": "accelerator-type",
          "runtime_version": "tpu-vm-runtime-version",
          "network_config": {
            "network": "your-network-name",
            "subnetwork": "your-subnetwork-name",
            "enable_external_ips": true
          },
          "tags" : ["example-tag-1"]
          "metadata": {
            "startup-script": "your-startup-script"
          }
      },
      "multi_node_params": {
        "node_count": slice-count,
        "node_id_prefix": "your-queued-resource-id"
      }
    }
    ]
  }
}
  • your-project-number - 您的 Google Cloud 项目编号
  • your-zone - 要在其中创建二维码的区域
  • accelerator-type - 单个 Slice 的版本和大小
  • tpu-vm-runtime-version - TPU 虚拟机运行时版本
  • your-network-name -(可选)要附加二维码的网络
  • your-subnetwork-name -(可选)要附加二维码的子网
  • example-tag-1 - 可选,任意代码字符串
  • your-startup-script - 分配二维码时将运行的启动脚本
  • slice-count - 多切片环境中的 TPU 切片的数量
  • your-qr-id - 用户为二维码提供的 ID

如需了解详情,请参阅 REST 已排队的资源 API 所有可用选项的文档。

如需使用 Spot 容量,请替换以下内容:

"guaranteed": { "reserved": true } - "spot": {}

请移除此行以使用默认的按需容量。

使用 JSON 载荷提交二维码创建请求:

  $ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" -d @queuedresourcereq.json https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
  • your-project-id - 您的 Google Cloud 项目 ID
  • your-zone - 要在其中创建二维码的区域
  • your-qr-id - 用户为二维码提供的 ID

响应应如下所示:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
  },
  "done": false
}

使用 name 属性的字符串值末尾的 GUID 值来获取 关于二维码请求的信息。

检索二维码的状态

如需获取二维码请求的状态,请使用以下命令:

  $ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
  • your-project-id - 您的 Google Cloud 项目 ID
  • your-zone - 要在其中创建二维码的区域
  • your-qr-guid - 指定 API 输出中 name 后面的 GUID 二维码创建请求。

此命令的响应包含操作的状态:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qa-guid>,
  "metadata": {...},
  "done": true,
  "response": {
    "@type": "type.googleapis.com/google.cloud.tpu.v2.QueuedResource",
    ...
    "state": {
      "state": "WAITING_FOR_RESOURCES"
    }
  }
}

如果二维码创建成功 ("done = true"),则 response 字段将为 WAITING_FOR_RESOURCESFAILED。 如果二维码处于 WAITING_FOR_RESOURCES 状态,则表示其已 已加入队列,并在有足够的资源后开始预配。如果二维码 处于 FAILED 状态,输出中会显示失败原因。有关 有关其他可能状态的信息,请参阅 已加入队列的资源用户指南

操作完成后,请使用描述二维码 来监控二维码的各个阶段。

在极少数情况下,您可能会发现二维码处于 FAILED 状态,而一些 Slice 为 ACTIVE。如果发生这种情况,请删除已创建的资源 然后过几分钟再试,或者与我们联系 联系 Cloud TPU 团队解决此问题。

SSH 和安装依赖项

在 TPU Pod 切片上运行 JAX 代码 介绍了如何在单个切片中使用 SSH 连接到您的 TPU 虚拟机。接收者 通过 SSH 连接到多切片环境中的所有 TPU 虚拟机, 使用以下 gcloud 命令安装依赖项:

  $ gcloud compute tpus queued-resources ssh ${your-qr-id} \
    --zone your-zone \
    --node=all \
    --worker=all \
    --command="command-to-run"
    --batch-size=4

gcloud 命令将指定的命令发送至以下域中的所有工作器和节点: 使用 SSH 进行二维码。该命令会被分成四组分批发送 。在当前批次时,系统会发送下一批命令 完成执行。如果其中某个命令出现故障, 停止,并且不再发送其他批次。有关详情,请参阅 已加入队列的资源 API 参考文档。 如果您使用的切片数量超过本地计算机的线程处理 则会导致死锁。例如 假设本地机器上的批处理限制为 64。如果您尝试 超过 64 个切片(比如 100 个切片)上执行训练脚本时,SSH 命令会破坏 将切片分成几批。它将在第一批 64 和 64 并等待脚本完成运行,然后再在 剩下的 36 个切片。但是,第一批 64 个切片不能 直到剩下的 36 个切片开始运行脚本, 死锁。

为防止出现这种情况,您可以在 在您指定的脚本命令中附加和号 (&),为每个虚拟机 。--command执行此操作时,在启动训练脚本后, 在第一批切片上,控制权将立即恢复为 SSH 命令然后,SSH 命令可以开始在 剩下的 36 个切片。您需要为 stdoutstderr 使用管道。 在后台运行命令时进行适当的流式传输。增加 并行处理,则可以使用 --node 选择特定 Slice 参数。

网络设置

通过执行以下步骤,确保 TPU 切片可以相互通信。 在每个切片上安装 JAX。如需了解详情,请参阅 在 TPU Pod 切片上运行 JAX 代码。断言 len(jax.devices())等于多切片中的条状标签数量 环境为此,请在每个 Slice 上运行:

  $ python3 -c 'import jax; print(jax.devices())'

如果您在 v4-16 的四个切片上运行此代码,则每个 Slice 和 4 Slice,总共应返回 32 个芯片(设备) 上传者:jax.devices()

列出快速回复

您可以使用 queued-resources list 命令查看二维码的状态:

$ gcloud compute tpus queued-resources list

NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
...
que-res-id  us-central2-b  4           v4-16             ACTIVE
...

描述二维码

要查看二维码的详细配置和状态,请使用 QR API 的说明。您可以使用 gcloudcurl 调用此 API。

使用 gcloud

$ gcloud compute tpus queued-resources describe ${your-qr-id}
...state:
 state: ACTIVE
...

使用 curl

$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/${your-qr-id}
{
  "name": your-queued-res,
  "tpu": {
    "nodeSpec": [
      {
        ... // node 1
      },
      {
        ... // node 2
      },
      ...
    ]
  },
  ...
  "state": "ACTIVE"
}

state 表示二维码的状态。如需详细了解 二维码的状态,请参阅已加入队列的资源

在预配的环境中启动作业

您可以通过 SSH 连接到每个切片中的所有主机,从而手动运行工作负载 并在所有主机上运行以下命令。

$ gcloud compute tpus tpu-vm ssh your-qr-id \
  --zone=your-zone \
  --worker=all \
  --node=all \
  --command="command-to-run"

重置二维码

ResetQueuedResource API 可用于重置 ACTIVE QR 码中的所有虚拟机。重置虚拟机会强制清空 并将虚拟机重置为初始状态。本地存储的所有数据都会 且在重置后会调用启动脚本。通过 如果您想重启所有 TPU,ResetQueuedResource API 会很有用。对于 例如训练停滞,重置所有虚拟机比调试更容易。

所有虚拟机的重置都是并行执行的,并且 ResetQueuedResource 需要一到两分钟才能完成如需调用该 API,请使用以下命令 命令:

$ gcloud compute tpus queued-resources reset your-qr-id

删除二维码

如需在培训课程结束时释放资源,请删除已排入队列的资源 使用 --force 标志管理资源。删除过程将需要两到五分钟的时间 完成,可使用可选的 --async 标志在后台运行。

$ gcloud compute tpus queued-resources \
delete your-qr-id --force (--async)

自动故障恢复

如果发生服务中断,Multislice 提供无干预 修复受影响的切片并重置所有切片受影响的 将替换为新的切片,并将其余健康状况良好的切片 重置。如果 则训练会停止。

要在中断后自动继续训练,您必须指定 启动脚本,用于检查 并加载上次保存的检查点。您的启动脚本会自动运行 系统在每次重新分配切片或重置虚拟机时触发。你指定一家初创公司 脚本。

以下启动脚本(在创建二维码中使用) 让您可以从故障中自动恢复并从中恢复训练, MaxText 训练期间存储在 Cloud Storage 存储桶中的检查点:

{
 "tpu": {
   "node_spec": [
     {
      ...
         "metadata": {
               "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
         }
     ...
     }
   ]
 }
}

请先克隆 MaxText 代码库,然后再尝试执行此操作 。

性能分析和调试

在单切片和多切片环境中进行性能分析是相同的。对于 如需了解详情,请参阅剖析 JAX 程序的性能

优化培训

使用多切片进行分片,以获得最佳性能

若要在多切片环境中实现最高性能,需要 考虑如何在多个切片中进行分片。通常有三种 选项(数据并行处理、完全分片数据并行处理和流水线并行处理)。 我们不建议跨模型维度对激活进行分片(有时 称为并行张量),因为它需要过多的切片间带宽。 对于所有这些策略,您可以在一个切片中使用相同的分片策略 这一方法。

我们建议从纯数据并行处理入手。使用完全分片的数据 同时并行处理有助于释放内存。但其缺点是 切片之间的通信会使用 DCN 网络, 工作负载仅在必要时根据批次大小使用流水线并行处理 (如下文分析)。

何时使用数据并行处理

如果您的工作负载 但您希望通过横向伸缩 多个切片。

要跨多个切片实现强有力的伸缩 在 DCN 上执行全部缩减操作所需的时间必须短于在 DCN 上 以便执行向后传递。DCN 用于切片与 是工作负载吞吐量的限制因素。

每个 v4 TPU 芯片的执行速度峰值为每秒 275 * 1012 FLOPS。

每个 TPU 主机有四个芯片,每个主机都有最大网络带宽 50 Gbps 的带宽

也就是说,算术强度 为 4 * 275 * 1012 FLOPS / 50 Gbps = 22000 FLOPS / 位。

您的模型会为每个步骤的每个参数使用 32 到 64 位的 DCN 带宽。 如果您使用两个切片,您的模型将使用 32 位的 DCN 带宽。如果您 使用两个以上的 Slice,则编译器会执行完全 shuffle 和 all-reduce 每个参数可以使用多达 64 位的 DCN 带宽, 操作。每个参数所需的 FLOPS 数量因您的 model.具体而言,对于基于 Transformer 的语言模型,FLOPS 的数量 所需的正向和反向传递大约为 6 * B * P,其中:

  • B 是批次大小,以词元为单位
  • P 是参数数量

每个参数的 FLOPS 数量为 6 * B,而 FLOPS 数量为每个参数 为 4 * B

为确保跨多个切片实现强有力的伸缩,请确保运维套件 强度超过了 TPU 硬件的算术强度。要计算 计算操作强度,请用每个参数的 FLOPS 数量除以 每个步骤的每个参数向后传递网络带宽(以位为单位): Operational Intensity = FLOPSbackwards_pass / DCN bandwidth

因此,对于基于 Transformer 的语言模型,如果您使用两个切片: Operational intensity = 4 * B / 32

如果您使用两个以上的 Slice:Operational intensity = 4 * B/64

这表明 Transformer 的最小批次大小在 176k 到 352k 之间 构建自己的语言模型。由于 DCN 网络可能短暂地丢弃数据包 最好保留明显的容错空间,仅部署数据并行处理 如果每个 Pod 的批次大小至少为 350k(两个 Pod)到 700k(许多 Pod)。

对于其他模型架构,您需要估算应用的 每个切片向后传递(使用性能分析器计时或计数) FLOPS)。然后,您可以将其与预期运行时间进行比较, DCN,并大致了解数据并行处理是否适合您。

何时使用完全分片数据并行处理 (FSDP)

完全分片数据并行处理 (FSDP) 结合了数据并行处理(将 并将权重分片到各节点。对于 向前传递和向后传递,权重会得到收集, 具有所需的权重。我们不像使用 all-reduce,则梯度在生成时进行归约散射。这样, 每个切片仅获取其所负责的权重的梯度。

与数据并行处理类似,FSDP 将需要伸缩全局批次大小 与切片数量呈线性关系。FSDP 会减少内存压力,因为 您需要增加切片的数量。这是因为 每个切片的优化器状态减少,但代价是增加的 并更有可能因 组织。

实际上,如果您按每个切片的批量大小增加批量大小, 存储更多激活,以最大限度地减少在 向后传递或增加神经网络中的参数数量。

FSDP 中的 all-gather 和 all-reduce 操作与 DP 中的类似。 因此您可以确定 FSDP 工作负载是否受到 DCN 性能的限制, 创建容器。

何时使用流水线并行处理

在使用其他 Google Cloud 产品实现高性能时, 并行处理策略,这些策略要求全局批次大小超过 首选最大批次大小。流水线并行处理允许切片 组成流水线,以“共享”生成一批。然而,流水线并行处理 明显的缺点:

  1. 会产生“管道气泡”芯片处于空闲状态 数据。
  2. 它需要微批处理,这减小了有效批次大小、 算术强度,并最终对 FLOP 利用率进行建模。

仅当存在其他并行策略时,才应使用流水线并行处理功能 全局批次大小过大。在尝试并行处理流水线之前 可以凭经验判断每个样本的收敛速度是否在 实现高性能 FSDP 所需的批次大小。FSDP 往往能够 更高的模型 FLOP 利用率,但如果每个样本的收敛速度随着 批次大小流水线并行处理可能仍是更好的选择。大多数人 工作负载可以容忍足够大的批次大小,因此无法从 但您的工作负载可能有所不同

如果需要流水线并行处理,我们建议将其与数据结合使用 即 FSDP。这样,您就可以最大程度地减少流水线深度 增加每个流水线批次大小,直到 DCN 延迟时间减少 都要考虑吞吐量具体而言,如果您有 N 个切片,不妨考虑 数据并行处理深度为 2 和 N/2 的副本,然后为深度为 4 和 N/4 的流水线 数据并行副本等,直到每个流水线的批次变大 DCN 集合可以隐藏在 反向传递。这样可以最大限度地缩短流水线 同时允许扩容超过全局批次大小限制。

多切片最佳实践

数据加载

在训练期间,我们会反复加载数据集中的批次,以馈送到 model.使用高效的异步数据加载器,将批量数据分片 对避免资源耗尽 TPU 的工作负担至关重要。当前数据加载器 让每个主机加载相等的样本子集。此解决方案是 足以处理文本,但需要在模型中重新分片。此外,MaxText 目前尚未提供允许数据迭代器的确定性快照 以便在抢占前和抢占后加载相同的数据。

检查点

Orbax 检查点库提供了 用于将 JAX PyTree 建立检查点到本地存储空间或 Google Cloud Storage 的基元。 我们在 MaxText 中提供了同步检查点的参考集成 在 checkpointing.py 中。

受支持的配置

形状

所有 Slice 的形状必须相同(例如,相同的 AcceleratorType)。 不支持异构切片形状。

编排

GKE 支持编排。如需更多信息 请参阅 GKE 中的 TPU

框架

多切片仅支持 JAX 和 PyTorch 工作负载。

最大并行数量

我们建议用户使用数据并行处理功能测试多切片。了解详情 了解如何使用多切片实现流水线并行处理,请联系您的 Google Cloud 客户代表。

支持与反馈

我们欢迎各种反馈!如需分享反馈或请求支持,请与我们联系 填写 Cloud TPU 支持或反馈表单