使用 Autocheckpoint 保留訓練進度

過去,當 TPU VM 需要維護時,程序會立即啟動,使用者沒有時間執行儲存檢查點等保留進度的動作。如圖 1(a) 所示。

圖表:顯示主機維護作業的影響 (有和沒有自動檢查點)

圖 1. 自動檢查點功能插圖: (a) 如果沒有自動檢查點,即將進行維護時,系統會遺失上次檢查點的訓練進度。(b) 透過自動檢查點,系統會在即將進行維護作業時,保留上次檢查點後的訓練進度。

您可以透過自動檢查點 (圖 1(b)) 保留訓練進度,方法是設定程式碼,在發生維護事件時儲存非排程檢查點。發生維護事件時,系統會自動儲存上次檢查點後的進度。這項功能適用於單一切片和多切片。

自動檢查點功能適用於可擷取 SIGTERM 信號,並隨後儲存檢查點的架構。支援的架構包括:

使用 Autocheckpoint

自動檢查點功能預設為停用。建立 TPU 或要求排入佇列的資源時,您可以在佈建 TPU 時新增 --autocheckpoint-enabled 標記,啟用自動檢查點功能。啟用這項功能後,Cloud TPU 收到維護事件通知時,會執行下列步驟:

  1. 擷取傳送至使用 TPU 裝置之程序的 SIGTERM 信號
  2. 等待程序結束或經過 5 分鐘 (兩者取其先)
  3. 對受影響的切片執行維護作業

Autocheckpoint 使用的基礎架構與 ML 架構無關。只要能擷取 SIGTERM 信號並啟動檢查點程序,任何機器學習架構都能支援自動檢查點。

在應用程式程式碼中,您需要啟用 ML 架構提供的自動檢查點功能。舉例來說,在 Pax 中,這表示啟動訓練時要啟用指令列標記。詳情請參閱 Pax 的 Autocheckpoint 快速入門導覽課程。 在幕後,架構會在收到 SIGTERM 訊號時儲存非排程檢查點,而受影響的 TPU VM 會在 TPU 不再使用時進行維護。

快速入門導覽課程:使用 MaxText 自動檢查點

MaxText 是高效能、可任意擴充的開放原始碼 LLM,以純 Python/JAX 編寫,適用於 Cloud TPU,且經過完善測試。MaxText 包含使用自動檢查點功能所需的所有設定。

MaxText README 檔案說明瞭兩種大規模執行 MaxText 的方式:

使用 multihost_runner.py 時,請在佈建佇列資源時設定 autocheckpoint-enabled 旗標,啟用自動檢查點。

使用 multihost_job.py 時,請在啟動工作時指定 ENABLE_AUTOCHECKPOINT=true 指令列旗標,啟用自動檢查點。

快速入門導覽課程:在單一切片上使用 Pax 進行自動檢查點

本節提供範例,說明如何使用 Pax 在單一切片上設定及使用 Autocheckpoint。完成適當設定後:

  • 發生維護事件時,系統會儲存檢查點。
  • 儲存檢查點後,Cloud TPU 會對受影響的 TPU VM 執行維護作業。
  • Cloud TPU 完成維護作業後,您就能照常使用 TPU VM。
  1. 建立 TPU VM 或要求佇列資源時,請使用 autocheckpoint-enabled 標記。

    例如:

    1. 設定環境變數:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      環境變數說明

      變數 說明
      PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
      TPU_NAME TPU 的名稱。
      ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
      ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
      RUNTIME_VERSION Cloud TPU 軟體版本

    2. 在有效設定中設定專案 ID 和可用區:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. 建立 TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. 使用 SSH 連線至 TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. 在單一切片上安裝 Pax

    自動檢查點功能適用於 Pax 1.1.0 以上版本。在 TPU VM 上安裝 jax[tpu] 和最新版 paxml

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. 設定 LmCloudSpmd2B 模型。執行訓練指令碼前,請將 ICI_MESH_SHAPE 變更為 [1, 8, 1]

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. 使用適當的設定啟動訓練。

    以下範例說明如何設定 LmCloudSpmd2B 模型,將 Autocheckpoint 觸發的檢查點儲存至 Cloud Storage 值區。將 your-storage-bucket 換成現有值區的名稱,或建立新值區

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    請注意傳遞至指令的兩個旗標:

    • jax_fully_async_checkpoint: 開啟這個標記後,系統會使用 orbax.checkpoint.AsyncCheckpointer。 訓練指令碼收到 SIGTERM 訊號時,AsyncCheckpointer 類別會自動儲存檢查點。
    • exit_after_ondemand_checkpoint: 開啟這個標記後,TPU 程序會在成功儲存自動檢查點後結束,立即觸發執行維護作業。如果沒有使用這個標記,訓練作業會在儲存檢查點後繼續進行,而 Cloud TPU 會等待逾時 (5 分鐘),然後執行必要維護作業。

使用 Orbax 自動檢查點

自動檢查點功能不限於 MaxText 或 Pax。只要架構可以擷取 SIGTERM 信號並啟動檢查點程序,就能與 Autocheckpoint 提供的基礎架構搭配使用。Orbax 命名空間提供這些功能,可為 JAX 使用者提供常見的公用程式庫。

Orbax 文件所述,這些功能預設會為 orbax.checkpoint.CheckpointManager 使用者啟用。每個步驟後呼叫的 save 方法會自動檢查是否即將發生維護事件,如果是,即使步驟編號不是 save 的倍數,也會儲存檢查點。save_interval_stepsGitHub 文件也說明如何修改使用者程式碼,讓訓練作業在儲存 Autocheckpoint 後結束。