Cloud TPU Autocheckpoint [公開プレビュー]

概要

これまでのように、TPU VM でメンテナンスが必要な場合、ユーザーがチェックポイントの保存などの進行状況維持アクションを実行するための時間を確保することなく、手順がすぐに開始されます。これを図 1(a)に示します。

Autocheckpoint

図 1.Autocheckpoint 機能の図:(a)Autocheckpoint がない場合、今後のメンテナンス イベントが発生すると、最後のチェックポイントからのトレーニングの進行状況が失われます。(b)Autocheckpoint を使用すると、今後のメンテナンス イベントの際に、最後のチェックポイント以降のトレーニングの進行状況を保持できます。

Autocheckpoint(図 1(b))を使用すると、メンテナンス イベントが発生したときにスケジュールが設定されていないチェックポイントを保存するようにコードを構成することで、トレーニングの進行状況を保持できます。メンテナンス イベントが発生すると、最後のチェックポイント以降の進行状況が自動的に保存されます。この機能は、シングル スライスとマルチスライスの両方で機能します。

Autocheckpoint 機能は、SIGTERM をキャプチャしてチェックポイントを保存できるフレームワークで動作します。サポートされているフレームワークには、MaxTextPaxOrbax を使用する JAX が含まれています。追加のフレームワークのサポートは、利用可能になり次第、発表されます。

現時点で、この機能を使用できるのは、Cloud TPU API で作成された TPU(v2-v4 と v5e)のみです。GKE での TPU のサポートは、利用可能になり次第、発表されます。

Autocheckpoint の使用

Autocheckpoint 機能はデフォルトで無効になっています。TPU またはキューに格納されたリソースを作成する場合は、TPU をプロビジョニングするときに --autocheckpoint-enabled フラグを追加することで有効にできます。この機能を有効にすると、Cloud TPU はメンテナンス イベントの通知を受信したら、次の手順を実行します。

  1. TPU デバイスを使用してプロセスに送信された SIGTERM をキャプチャします。
  2. プロセスが終了するまで待機するか、5 分が経過するかのいずれか先の時点で待機し、影響を受けたスライスのメンテナンスを行います。

Autocheckpoint で使用されるインフラストラクチャは、ML フレームワークに依存しないことに注意してください。SIGTERM シグナルをキャプチャしてチェックポインティング プロセスを開始できれば、どの ML フレームワークでも Autocheckpoint をサポートできます。

アプリケーション コードで、ML フレームワークが提供する Autocheckpoint 機能を有効にする必要があります。たとえば、Pax の場合、トレーニングの起動時にコマンドライン フラグを有効にします(Pax による Autocheckpoint クイックスタートをご覧ください)。バックグラウンドで、SIGTERM を受信すると、スケジュールが設定されていないチェックポイントが保存され、TPU VM が使用されなくなった場合、影響を受ける TPU VM がメンテナンスに入ります。

クイックスタート: MaxText での Autocheckpoint

MaxText は、「Cloud TPU をターゲットとする純粋な Python/JAX で記述された、高パフォーマンスで任意にスケーラブルなオープンソースで十分にテストされた LLM」です。MaxText には、Autocheckpoint 機能を使用するために必要なすべての設定が含まれています。

MaxText の README には、MaxText を大規模に実行するための 2 つの方法が説明されています。

multihost_runner.py を使用する場合、必要な変更は、キューに入れられたリソースをプロビジョニングするときに autocheckpoint-enabled フラグを設定することです。multihost_job.py を使用する場合、必要な変更は、ジョブの起動時に ENABLE_AUTOCHECKPOINT=true コマンドライン フラグを指定することのみです。

クイックスタート: シングル スライス上の Pax での Autocheckpoint

このセクションでは、単一のスライス上の Pax で Autocheckpoint を設定して使用する方法の例を示します。適切な設定:

  • メンテナンス イベントが発生すると、チェックポイントが保存されます。
  • Cloud TPU は、チェックポイントの保存後に、影響を受ける TPU VM のメンテナンスを実行します。
  • Cloud TPU のメンテナンスが完了すると、通常どおり TPU VM を使用できます。
  1. TPU VM またはキューに格納されたリソースを作成するときに、autocheckpoint-enabled フラグを使用します。

    次に例を示します。

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
  2. 単一のスライスに Pax をインストールする

    Autocheckpoint 機能は、Pax バージョン 1.1.0 以降で動作します。TPU VM で、jax[tpu] と最新の paxml をインストールします。

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. 適切な構成でトレーニングを開始する

    次の例は、Autocheckpoint によってトリガーされたチェックポイントを Google Cloud Storage バケットに保存するように、LmCloudSpmd2B モデルを構成する方法を示しています。

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    コマンドに渡される 2 つのフラグに注意してください。

    • jax_fully_async_checkpoint: このフラグを有効にすると、orbax.checkpoint.AsyncCheckpointer が使用されます。AsyncCheckpointer クラスは、トレーニング スクリプトが SIGTERM シグナルを受信すると、チェックポイントを自動的に保存します。
    • exit_after_ondemand_checkpoint: このフラグをオンにすると、Autocheckpoint が正常に保存された後に TPU プロセスが終了し、メンテナンスがすぐに実行されます。このフラグを使用しない場合、チェックポイントの保存後にトレーニングが続行され、Cloud TPU は必要なメンテナンスを実行する前にタイムアウトが発生するまで待機します(5 分)。

クイックスタート: マルチスライス上の Pax での Autocheckpoint

Autocheckpoint は、単一のスライスだけでなく、マルチスライスにも機能します。このセクションでは、マルチスライスで Autocheckpoint を使用するために必要な手順について詳しく説明します。

  1. キューに格納されたリソースの作成中に Autocheckpoint を指定します。

    マルチスライス環境は、キューに入れられたリソース リクエストを介してのみプロビジョニングできます。 シングルスライスの場合と同様に、呼び出しで autocheckpoint-enabled フラグを使用して、キューに格納されたリソースを作成します。

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled

    使用可能なすべてのオプションの詳細については、マルチスライス ユーザーガイドをご覧ください。キューに格納されたリソース リクエストが作成され、ACTIVE 状態になったら、次の手順で Autocheckpoint を使用して Pax を実行します。

  2. マルチスライス環境のすべての VM に Pax をインストールします。

    TPU VM で、マルチスライス環境のすべての TPU VM に jax[tpu] と最新の paxml をインストールします。

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. 適切な構成でトレーニングを開始する

    この例では、マルチスライス環境でトレーニングする際に、Autocheckpoint 用のモデル LmCloudSpmd2B を構成する方法を示しています。トレーニング スクリプトを実行する前に、次のコードに示すように DCN_MESH_SHAPE を [2, 1, 1] に設定します。

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]

    トレーニングを起動するときは、シングル スライスの場合で説明したコマンドライン フラグに加えて、さらに 3 つのフラグが必要です。

    • num_hosts: ホストの合計数。この場合は 2 です。
    • host_index: トレーニングを開始するホストのインデックス。0 から N-1 まで変動します。ここで、N はホストの合計数です。
    • server_addr: ノード 0 のワーカー 0 の IP アドレス(未使用のポートを含む)(例: 8476)。この情報を確認するには、ノード 0 のワーカー 0 で hostname -i を使用します。

Orbax での Autocheckpoint

Autocheckpoint 機能は、MaxText または Pax に限定されません。SIGTERM シグナルをキャプチャしてチェックポインティング プロセスを開始できるフレームワークは、Autocheckpoint によって提供されるインフラストラクチャと連携します。JAX ユーザーに共通のユーティリティ ライブラリを提供する名前空間である Orbax が、これらの機能を提供します。

Orbax のドキュメントで説明したように、これらの機能は orbax.checkpoint.CheckpointManager のユーザーに対してデフォルトで有効になっています。各ステップの後に呼び出される save メソッドは、メンテナンス イベントが差し迫っているかどうかを自動的に確認します。その場合は、ステップ番号が save_interval_steps の倍数でない場合でもチェックポイントを保存します。また、GitHub のドキュメントでは、ユーザーコードを変更して、Autocheckpoint を保存した後にトレーニングを終了する方法も示しています。