Cloud TPU Autocheckpoint [公开预览版]

概览

过去,当 TPU 虚拟机需要维护时,该过程会立即启动,用户无需留出时间来执行进度保留操作(例如保存检查点)。如图 1(a) 所示。

自动检查点

图 1.自动检查点功能的图示: (a) 如果没有自动检查点,当有即将进行的维护事件时,上一个检查点的训练进度将会丢失。(b) 借助 Autocheckpoint,当有即将进行的维护事件时,可以保留自上一个检查点以来的训练进度。

您可以使用自动检查点(图 1(b))来保留训练进度,只需将代码配置为在维护事件发生时保存非预定检查点即可。发生维护事件时,系统会自动保存自上一个检查点以来的进度。该功能适用于单个切片和多切片。

Autocheckpoint 功能可与可以捕获 SIGTERM 并随后保存检查点的框架搭配使用。支持的框架包括 MaxTextPax 以及具有 Orbax 的 JAX。我们计划在推出更多框架后公布对相应框架的支持。

目前,只有通过 Cloud TPU API 创建的 TPU(v2-v4 和 v5e)才能使用此功能。我们会在 GKE 推出对 TPU 的支持时公布。

使用 Autocheckpoint

自动检查点功能默认处于停用状态。创建 TPU 或排队资源时,您可以在预配 TPU 时添加 --autocheckpoint-enabled 标志来启用 TPU。启用此功能后,Cloud TPU 会在收到维护事件通知后执行以下步骤:

  1. 捕获使用 TPU 设备发送到进程的 SIGTERM,
  2. 等待进程退出,或者等待 5 分钟(以先发生者为准),并对受影响的切片执行维护。

请注意,Autocheckpoint 使用的基础架构独立于机器学习框架。 任何机器学习框架都可以支持 Autocheckpoint,但前提是它可以捕获 SIGTERM 信号并启动检查点进程。

在应用代码中,您需要启用机器学习框架提供的 Autocheckpoint 功能。例如,在 Pax 中,这意味着在启动训练时启用命令行标志(请参阅 Pax 自动检查点快速入门)。在后台,收到 SIGTERM 后,框架会保存非计划检查点;当不再使用 TPU 时,受影响的 TPU 虚拟机进行维护。

快速入门:使用 MaxText 进行自动检查

MaxText 是“以 Cloud TPU 为目标、使用纯 Python/JAX 编写且经过测试的高性能、可任意伸缩的开源 LLM”。 MaxText 包含使用自动检查点功能所需的所有设置。

MaxText 自述文件介绍了两种大规模运行 MaxText 的方法:

使用 multihost_runner.py 时,唯一需要的更改是在预配排队的资源时设置 autocheckpoint-enabled 标志。使用 multihost_job.py 时,唯一需要的更改是在启动作业时指定 ENABLE_AUTOCHECKPOINT=true 命令行 flag。

快速入门:在单个切片上使用 Pax 自动检查点

在本部分中,我们将举例说明如何在单个 Slice 中通过 Pax 设置和使用 Autocheckpoint。进行适当的设置后:

  • 维护事件发生时,系统会保存检查点。
  • 保存检查点后,Cloud TPU 将对受影响的 TPU 虚拟机执行维护。
  • Cloud TPU 完成维护后,您可以照常使用 TPU 虚拟机。
  1. 在创建 TPU 虚拟机或已加入队列的资源时使用 autocheckpoint-enabled 标志。

    例如:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. 在单个切片上安装 Pax

    自动检查点功能适用于 1.1.0 或更高版本的 Pax 版本。在 TPU 虚拟机上,安装 jax[tpu] 和最新的 paxml

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. 使用适当的配置启动训练

    以下示例展示了如何配置 LmCloudSpmd2B 模型,以将 Autocheckpoint 触发的检查点保存到 Google Cloud Storage 存储桶:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    请注意传递给命令的两个标志:

    • jax_fully_async_checkpoint:此标志开启时,将使用 orbax.checkpoint.AsyncCheckpointer。 当训练脚本收到 SIGTERM 信号时,AsyncCheckpointer 类会自动保存检查点。
    • exit_after_ondemand_checkpoint:启用此标志后,TPU 进程会在自动检查点成功保存后退出,从而触发立即执行维护。如果您不使用此标志,训练将在检查点保存后继续进行,并且 Cloud TPU 将等待超时(5 分钟)后再执行所需的维护。

快速入门:在 MultiSlice 上使用 Pax 进行自动检查

自动检查点不仅适用于单个切片,也适用于多切片。本部分详细介绍了将 Autocheckpoint 与 MultiSlice 搭配使用所需的步骤。

  1. 在创建排队的资源期间指定自动检查点。

    多切片环境只能通过排队的资源请求来预配。与单切片的情况类似,请在调用中使用 autocheckpoint-enabled 标志来创建排队资源。

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    如需详细了解所有可用选项,请参阅多切片用户指南。创建已加入队列的资源请求并处于 ACTIVE 状态后,请按照后续步骤通过 Autocheckpoint 运行 Pax。

  2. 在多切片环境中的所有虚拟机上安装 Pax。

    在 TPU 虚拟机上,在 MultiSlice 环境中的所有 TPU 虚拟机上安装 jax[tpu] 和最新的 paxml

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. 使用适当的配置启动训练

    以下示例展示了如何在多切片环境中为 Autocheckpoint 配置模型 LmCloudSpmd2B。在运行训练脚本之前,请将 DCN_MESH_SHAPE 设置为 [2, 1, 1],如以下代码所示:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    启动训练时,除了单切片案例中讨论的命令行标志外,还需要另外三个标志:

    • num_hosts:主机总数。在本例中,标签为 2。
    • host_index:启动训练的主机的索引。该值介于 0 到 N-1 之间,其中 N 是主机总数。
    • server_addr:节点 0 的工作器 0 的 IP 地址(具有未使用的端口,例如 8476)。如需查找此信息,请在节点 0 的工作器 0 上使用 hostname -i

使用 Orbax 自动检查点

Autocheckpoint 功能不仅限于 MaxText 或 Pax。任何能够捕获 SIGTERM 信号并启动检查点流程的框架都适用于 Autocheckpoint 提供的基础架构。Orbax 是一个为 JAX 用户提供常用实用程序库的命名空间,也提供了这些功能。

Orbax 文档中所述,默认情况下,系统会为 orbax.checkpoint.CheckpointManager 的用户启用这些功能。在每个步骤之后调用的 save 方法会自动检查维护事件是否即将发生,如果是,则即使步骤编号不是 save_interval_steps 的倍数,也会保存检查点。GitHub 文档还说明了如何在保存自动检查点后通过修改用户代码来退出训练。