使用 Autocheckpoint 保留訓練進度
過去,當 TPU VM 需要維護時,程序會立即啟動,使用者沒有時間執行儲存檢查點等保留進度的動作。如圖 1(a) 所示。
圖 1. 自動檢查點功能插圖: (a) 如果沒有自動檢查點,即將進行維護作業時,系統會遺失上次檢查點的訓練進度。(b) 透過自動檢查點,系統會在即將進行維護作業時,保留上次檢查點後的訓練進度。
您可以透過自動檢查點 (圖 1(b)) 保留訓練進度,方法是設定程式碼,在發生維護事件時儲存非排程檢查點。發生維護事件時,系統會自動儲存上次檢查點後的進度。這項功能適用於單一切片和多切片。
自動檢查點功能適用於可擷取 SIGTERM 信號,並隨後儲存檢查點的架構。支援的架構包括:
使用 Autocheckpoint
自動檢查點功能預設為停用。建立 TPU 或要求排入佇列的資源時,您可以在佈建 TPU 時新增 --autocheckpoint-enabled
標記,啟用自動檢查點功能。啟用這項功能後,Cloud TPU 收到維護事件通知時,會執行下列步驟:
- 使用 TPU 裝置擷取傳送至程序的 SIGTERM 信號
- 等待程序結束或經過 5 分鐘 (以先到者為準)
- 對受影響的切片執行維護作業
Autocheckpoint 使用的基礎架構與 ML 架構無關。只要能擷取 SIGTERM 信號並啟動檢查點程序,任何機器學習架構都能支援自動檢查點。
在應用程式程式碼中,您需要啟用 ML 架構提供的自動檢查點功能。舉例來說,在 Pax 中,這表示啟動訓練時要啟用指令列標記。詳情請參閱 Pax 的 Autocheckpoint 快速入門導覽課程。 在幕後,架構會在收到 SIGTERM 信號時儲存非排程檢查點,而受影響的 TPU VM 會在 TPU 不再使用時進行維護。
快速入門導覽課程:使用 MaxText 自動檢查點
MaxText 是高效能、可任意擴充的開放原始碼 LLM,以純 Python/JAX 編寫,適用於 Cloud TPU,且經過完善測試。MaxText 包含使用自動檢查點功能所需的所有設定。
MaxText README
檔案說明瞭大規模執行 MaxText 的兩種方式:
- 使用
multihost_runner.py
,建議用於實驗 - 使用
multihost_job.py
,建議用於正式版
使用 multihost_runner.py
時,請在佈建佇列資源時設定 autocheckpoint-enabled
旗標,啟用自動檢查點。
使用 multihost_job.py
時,請在啟動工作時指定 ENABLE_AUTOCHECKPOINT=true
指令列旗標,啟用自動檢查點。
快速入門導覽課程:在單一切片上使用 Pax 進行自動檢查點
本節提供範例,說明如何使用 Pax 在單一切片上設定及使用 Autocheckpoint。完成適當設定後:
- 發生維護事件時,系統會儲存檢查點。
- 儲存檢查點後,Cloud TPU 會對受影響的 TPU VM 執行維護作業。
- Cloud TPU 完成維護作業後,您就能照常使用 TPU VM。
建立 TPU VM 或要求佇列資源時,請使用
autocheckpoint-enabled
標記。例如:
設定環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
在有效設定中設定專案 ID 和可用區:
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
建立 TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
使用 SSH 連線至 TPU:
gcloud compute tpus tpu-vm ssh $TPU_NAME
在單一切片上安裝 Pax
自動檢查點功能適用於 Pax 1.1.0 以上版本。在 TPU VM 上安裝
jax[tpu]
和最新版paxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
設定
LmCloudSpmd2B
模型。執行訓練指令碼前,請將ICI_MESH_SHAPE
變更為[1, 8, 1]
:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
使用適當的設定啟動訓練。
以下範例說明如何設定
LmCloudSpmd2B
模型,將 Autocheckpoint 觸發的檢查點儲存至 Cloud Storage 值區。將 your-storage-bucket 換成現有值區的名稱,或建立新值區。export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
請注意傳遞至指令的兩個旗標:
jax_fully_async_checkpoint
: 開啟這個標記後,系統就會使用orbax.checkpoint.AsyncCheckpointer
。 訓練指令碼收到 SIGTERM 訊號時,AsyncCheckpointer
類別會自動儲存檢查點。exit_after_ondemand_checkpoint
: 開啟這個標記後,TPU 程序會在成功儲存自動檢查點後結束,立即觸發執行維護作業。如果未使用這個標記,訓練作業會在儲存檢查點後繼續進行,而 Cloud TPU 會等待逾時 (5 分鐘),然後執行必要維護作業。
使用 Orbax 自動檢查點
自動檢查點功能不限於 MaxText 或 Pax。只要架構可以擷取 SIGTERM 信號並啟動檢查點程序,就能與 Autocheckpoint 提供的基礎架構搭配使用。Orbax 命名空間提供這些功能,可為 JAX 使用者提供常見的公用程式庫。
如 Orbax 文件所述,這些功能預設會為 orbax.checkpoint.CheckpointManager
使用者啟用。每個步驟後呼叫的 save
方法會自動檢查維護事件是否即將發生,如果會,即使步驟編號不是 save
的倍數,也會儲存檢查點。save_interval_steps
GitHub 文件也說明如何修改使用者程式碼,在儲存 Autocheckpoint 後結束訓練。