使用 Autocheckpoint 保留训练进度

过去,当 TPU 虚拟机需要维护时,系统会立即启动该过程,而无需等待用户执行保存检查点等可保留进度的操作。如图 1(a) 所示。

显示启用和不启用自动检查点的主机维护影响的示意图

图 1. 自动检查点功能示意图:(a) 如果未启用自动检查点,当即将发生维护事件时,上一个检查点的训练进度会丢失。(b) 使用 Autocheckpoint 时,在即将发生维护事件时,系统可以保留自上次检查点以来的训练进度。

您可以使用 Autocheckpoint(图 1(b))来保留训练进度,方法是将代码配置为在发生维护事件时保存非预定的检查点。发生维护事件时,系统会自动保存自上次检查点以来的进度。此功能适用于单个 Slice 和 Multislice。

Autocheckpoint 功能适用于可捕获 SIGTERM 信号并随后保存检查点的框架。支持的框架包括:

使用 Autocheckpoint

自动检查点功能默认处于停用状态。创建 TPU 或请求已排队的资源时,您可以在预配 TPU 时添加 --autocheckpoint-enabled 标志,以启用自动检查点。启用此功能后,Cloud TPU 在收到维护事件通知后会执行以下步骤:

  1. 捕获使用 TPU 设备向进程发送的 SIGTERM 信号
  2. 等待进程退出或 5 分钟过后(以先到者为准)
  3. 对受影响的 Slice 执行维护

Autocheckpoint 使用的基础架构与机器学习框架无关。如果任何机器学习框架可以捕获 SIGTERM 信号并启动检查点过程,则可以支持 Autocheckpoint。

在应用代码中,您需要启用机器学习框架提供的 Autocheckpoint 功能。例如,在 Pax 中,这意味着在启动训练时启用命令行标志。如需了解详情,请参阅使用 Pax 的自动检查点快速入门。在后台,框架会在收到 SIGTERM 信号时保存非调度检查点,并且当 TPU 不再使用时,受影响的 TPU 虚拟机会进行维护。

快速入门:使用 MaxText 自动创建检查点

MaxText 是一个高性能、任意可伸缩、开源且经过充分测试的 LLM,以纯 Python/JAX 编写,以 Cloud TPU 为目标平台。MaxText 包含使用 Autocheckpoint 功能所需的所有设置。

MaxText README 文件介绍了两种大规模运行 MaxText 的方法:

使用 multihost_runner.py 时,请在预配队列资源时设置 autocheckpoint-enabled 标志,以启用 Autocheckpoint。

使用 multihost_job.py 时,请在启动作业时指定 ENABLE_AUTOCHECKPOINT=true 命令行 flag,以启用自动检查点。

快速入门:在单个 slice 上使用 Pax 自动创建检查点

本部分通过示例介绍了如何在单个 slice 上设置和使用 Autocheckpoint 与 Pax。通过适当的设置:

  • 当发生维护事件时,系统会保存一个检查点。
  • 保存检查点后,Cloud TPU 将对受影响的 TPU 虚拟机执行维护。
  • Cloud TPU 完成维护后,您可以照常使用 TPU 虚拟机。
  1. 创建 TPU 虚拟机或请求队列化资源时,请使用 autocheckpoint-enabled 标志。

    例如:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
        --accelerator-type $ACCELERATOR_TYPE \
        --version tpu-ubuntu2204-base \
        --autocheckpoint-enabled
  2. 在单个 slice 上安装 Pax

    Autocheckpoint 功能适用于 Pax 1.1.0 及更高版本。在 TPU 虚拟机上,安装 jax[tpu] 和最新的 paxml

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. 使用适当的配置启动训练。

    以下示例展示了如何配置 LmCloudSpmd2B 模型,以将 Autocheckpoint 触发的检查点保存到 Cloud Storage 存储桶:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    请注意传递给该命令的两个标志:

    • jax_fully_async_checkpoint:启用此标志后,系统将使用 orbax.checkpoint.AsyncCheckpointer。当训练脚本收到 SIGTERM 信号时,AsyncCheckpointer 类会自动保存检查点。
    • exit_after_ondemand_checkpoint:如果此标志处于启用状态,则 TPU 进程会在自动检查点成功保存后退出,这会触发立即执行维护。如果您不使用此标志,则在保存检查点后,训练将继续进行,并且 Cloud TPU 将等待超时(5 分钟)后再执行所需的维护。

快速入门:在多 Slice 上使用 Pax 进行自动检查点

自动检查点不仅适用于单个 Slice,也适用于多 Slice。本部分详细介绍了将 Autocheckpoint 与 Multislice 搭配使用所需的步骤。

  1. 在队列化资源创建期间指定 Autocheckpoint。

    只能通过排队的资源请求预配多 Slice 环境。与单个 slice 情形类似,在调用中使用 autocheckpoint-enabled 标志创建队列化资源。

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
        --node-count $NODE_COUNT \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version tpu-ubuntu2204-base \
        --autocheckpoint-enabled

    如需详细了解所有可用选项,请参阅多片解析用户指南。创建队列中的资源请求并将其置于 ACTIVE 状态后,请按照后续步骤使用 Autocheckpoint 运行 Pax。

  2. 在多 slice 环境中的所有虚拟机上安装 Pax。

    在 TPU 虚拟机上,在多切片环境中的所有 TPU 虚拟机上安装 jax[tpu] 和最新的 paxml

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. 使用适当的配置启动训练。

    此示例展示了如何在多 Slice 环境中训练时为 LmCloudSpmd2B 模型配置自动检查点。在运行训练脚本之前,请将 DCN_MESH_SHAPE 设置为 [2, 1, 1],如以下示例所示:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 4, 1]
        DCN_MESH_SHAPE = [2, 1, 1]

    启动训练时,除了单个 Slice 情况下讨论的命令行标志之外,还需要另外三个标志:

    • num_hosts:主机总数。在本例中,它为 2。
    • host_index:启动训练的主机的编号。该值介于 0 到 N-1 之间,其中 N 是主机总数。
    • server_addr:节点 0 的工作器 0 的 IP 地址,以及未使用的端口(例如 8476)。如需查找此信息,请对节点 0 的工作器 0 使用 hostname -i

使用 Orbax 进行自动检查点

Autocheckpoint 功能不仅适用于 MaxText 或 Pax。任何能够捕获 SIGTERM 信号并启动检查点生成进程的框架都可以与 Autocheckpoint 提供的基础架构搭配使用。Orbax 是一个为 JAX 用户提供常用实用程序库的命名空间,可提供这些功能。

Orbax 文档中所述,系统会默认为 orbax.checkpoint.CheckpointManager 用户启用这些功能。在每个步骤后调用的 save 方法会自动检查是否即将发生维护事件,如果即将发生,则保存检查点,即使步骤编号不是 save_interval_steps 的倍数也是如此。GitHub 文档还介绍了如何通过修改用户代码,在保存 Autocheckpoint 后让训练退出。