Cloud TPU 自动检查点 [公开预览版]
概览
过去,当 TPU 虚拟机 维护、 程序会立即启动,无需用户留出时间 执行保存进度的操作(例如保存检查点)。如图 1(a) 所示。
图 1. 自动检查点功能示意图: (a) 如果没有自动检查点,则使用上一个检查点训练的进度 会在有即将进行的维护事件时丢失(b) 使用 Autocheckpoint 时,当有即将进行的维护事件时,系统可以保留自上次检查点以来的训练进度。
您可以使用自动检查点(图 1(b))来保留训练进度, 配置代码以在维护期间保存非计划检查点 事件。发生维护事件时,系统会自动保存自上次检查点以来的进度。此功能适用于单个 Slice 和 Multislice。
Autocheckpoint 功能适用于可以捕获 SIGTERM 并随后保存检查点的框架。支持的框架包括 MaxText, Pax、 和 JAX 和 Orbax。 我们会在支持更多框架后及时公布相关信息。
目前,只有通过 Cloud TPU API 创建的 TPU(v2-v4 和 v5e)才能使用此功能。对 GKE 中 TPU 的支持将进一步提升
使用 Autocheckpoint
自动检查点功能默认处于停用状态。创建
TPU 或已加入队列的资源,
您可以在预配时添加 --autocheckpoint-enabled
标志来启用它
TPU
启用此功能后,Cloud TPU 在收到维护事件通知后会执行以下步骤:
- 捕获使用 TPU 设备向进程发送的 SIGTERM
- 等待进程退出或 5 分钟(以先到者为准),然后对受影响的 Slice 执行维护。
请注意,Autocheckpoint 使用的基础架构与机器学习框架无关。只要机器学习框架能够捕获 SIGTERM 信号并启动检查点过程,便可支持 Autocheckpoint。
在应用代码中,您需要启用机器学习框架提供的 Autocheckpoint 功能。例如,在 Pax 中 这意味着在启动 训练(请参阅 Pax 的自动检查点快速入门)。 在后台,框架会在收到 SIGTERM 时保存非调度检查点,并且当 TPU 不再使用时,受影响的 TPU 虚拟机会进行维护。
快速入门:使用 MaxText 自动创建检查点
MaxText 是一个“高性能, 使用纯 Python/JAX 编写的经过良好测试的可任意伸缩的开源 LLM 以 Cloud TPU 为目标平台”。 MaxText 包含使用自动检查点所需的所有设置 功能。
MaxText 自述文件介绍了两种大规模运行 MaxText 的方法:
- 使用
multihost_runner.py
,建议用于实验 - 使用
multihost_job.job
(建议用于生产环境)
使用 multihost_runner.py
时,唯一需要进行的更改
在预配时设置 autocheckpoint-enabled
标志
已加入队列的资源。使用 multihost_job.py
时,唯一需要更改的是,在启动作业时指定 ENABLE_AUTOCHECKPOINT=true
命令行标志。
快速入门:在单个 Slice 上使用 Pax 实现自动检查点
在本部分中,我们将举例说明如何设置和使用自动检查点 Pax 都在一个切片上。通过适当的设置:
- 当发生维护事件时,系统会保存一个检查点。
- 保存检查点后,Cloud TPU 将对受影响的 TPU 虚拟机执行维护。
- 在 Cloud TPU 完成维护后,您可以照常使用 TPU 虚拟机。
在创建 TPU 虚拟机时使用
autocheckpoint-enabled
标志,或 已加入队列的资源。例如:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
在单个 slice 上安装 Pax
自动检查点功能适用于 Pax 1.1.0 或更高版本。在 TPU 虚拟机上,安装
jax[tpu]
和最新的paxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
使用适当的配置启动训练
以下示例展示了如何配置
LmCloudSpmd2B
模型,以将 Autocheckpoint 触发的检查点保存到 Google Cloud Storage 存储桶:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
请注意传递给此命令的两个标志:
jax_fully_async_checkpoint
:启用此标志后,系统将使用orbax.checkpoint.AsyncCheckpointer
。AsyncCheckpointer
类会自动保存 在训练脚本收到 SIGTERM 信号时启动一个检查点。exit_after_ondemand_checkpoint
:如果此标志处于启用状态,则 TPU 进程会在自动检查点成功保存后退出,这会触发立即执行维护。如果您不使用 标志后,训练将在保存检查点后继续 Cloud TPU 将等待超时(5 分钟) 然后再执行所需的维护。
快速入门:在多 Slice 上使用 Pax 进行自动检查点
自动检查点不仅适用于单个 Slice,也适用于多 Slice。本部分详细介绍了将 Autocheckpoint 与 Multislice 搭配使用所需的步骤。
在创建排队的资源期间指定自动检查点。
只能通过排队的资源请求预配多 Slice 环境。与单个 slice 的情况类似,在调用中使用
autocheckpoint-enabled
标志可创建队列化资源。QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
如需详细了解所有可用选项,请参阅多 Slice 用户指南。队列化资源请求创建并处于
ACTIVE
状态后,请按照后续步骤使用 Autocheckpoint 运行 Pax。在多切片环境中的所有虚拟机上安装 Pax。
在 TPU 虚拟机上,在多切片环境中的所有 TPU 虚拟机上安装
jax[tpu]
和最新的paxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
使用适当的配置启动训练
以下示例展示了如何配置模型
LmCloudSpmd2B
在多切片环境中训练时使用的自动检查点。在运行训练脚本之前,请将 DCN_MESH_SHAPE 设置为 [2, 1, 1],如以下代码所示:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
启动训练时,除了使用命令行标志 除此之外,还需要添加三个:
num_hosts
:主机的总数。在本例中,它为 2。host_index
:启动训练的主机的编号。因情况而异 从 0 到N-1
,其中N
是主机的总数。server_addr
:节点 0 的工作器 0 的 IP 地址,有一个未使用的 (例如 8476)。如需查找此信息,请对节点 0 的工作器 0 使用hostname -i
。
使用 Orbax 自动创建检查点
Autocheckpoint 功能不仅适用于 MaxText 或 Pax。任何框架 该服务器可以捕获 SIGTERM 信号并发起 检查点流程与 Autocheckpoint 提供的基础架构协同工作。 Orbax 是一个 通用实用程序库提供了这些功能。
如 Orbax 文档中所述,系统会默认为 orbax.checkpoint.CheckpointManager
用户启用这些功能。save
方法
每个步骤之后调用的这个函数会自动检查
事件即将到来,如果是,则保存检查点,即使步骤编号
不是 save_interval_steps
的倍数。
GitHub 文档
也说明了如何在保存
设置了自动检查点,并在用户代码中进行了修改。