Cloud TPU Autocheckpoint [公開プレビュー]

概要

従来、TPU VM でメンテナンスが必要な場合、ユーザーはチェックポイントの保存など、進行状況を保持するアクションを実行する時間を確保せずに、プロシージャがすぐに開始されます。これを図 1(a)に示します。

Autocheckpoint

図 1.Autocheckpoint 機能の図: (a)Autocheckpoint が存在しない場合、今後のメンテナンス イベントが発生すると、最後のチェックポイントからのトレーニングの進行状況が失われます。(b)Autocheckpoint を使用すると、今後のメンテナンス イベントの際に、最後のチェックポイント以降のトレーニングの進行状況を保持できます。

Autocheckpoint(図 1(b))を使用すると、メンテナンス イベントが発生したときにスケジュールが設定されていないチェックポイントを保存するようにコードを構成することで、トレーニングの進行状況を保持できます。メンテナンス イベントが発生すると、最後のチェックポイント以降の進行状況が自動的に保存されます。この機能は、単一スライスとマルチスライスの両方で機能します。

Autocheckpoint 機能は、SIGTERM をキャプチャしてチェックポイントを保存できるフレームワークで動作します。サポートされているフレームワークには、MaxTextPaxOrbax を使用した JAX があります。他のフレームワークのサポートは、利用可能になり次第、発表されます。

現時点では、Cloud TPU API で作成された TPU(v2-v4、v5e)のみがこの機能を使用できます。GKE での TPU のサポートは、利用可能になった時点で発表されます。

Autocheckpoint の使用

Autocheckpoint 機能はデフォルトで無効になっています。TPU またはキューに格納されたリソースを作成する場合は、TPU のプロビジョニング時に --autocheckpoint-enabled フラグを追加することで有効にできます。この機能を有効にすると、Cloud TPU はメンテナンス イベントの通知を受信したら、次の手順を実行します。

  1. TPU デバイスを使用してプロセスに送信された SIGTERM をキャプチャします。
  2. プロセスが終了するまで、または 5 分間が経過するまで、どちらかの時点で待機し、影響を受けるスライスでメンテナンスを実行します。

Autocheckpoint で使用されるインフラストラクチャは ML フレームワークに依存しません。SIGTERM シグナルをキャプチャしてチェックポインティング プロセスを開始できる場合は、ML フレームワークで Autocheckpoint をサポートできます。

アプリケーション コードでは、ML フレームワークの提供する Autocheckpoint 機能を有効にする必要があります。たとえば、Pax の場合、トレーニングの起動時にコマンドライン フラグを有効にします(Pax による Autocheckpoint クイックスタートをご覧ください)。SIGTERM は、フレームワークによって、スケジュールされていないチェックポイントが保存され、TPU VM が使用されなくなったときにメンテナンスが実施されます。

クイックスタート: MaxText での Autocheckpoint

MaxText は、「高パフォーマンスで任意にスケーリングできる、オープンソースで十分にテストされた LLM で、Cloud TPU を採用している純粋な Python/JAX で記述されています」。MaxText には、Autocheckpoint 機能の使用に必要なすべての設定が含まれています。

MaxText の README では、MaxText を大規模に実行するための 2 つの方法について説明します。

multihost_runner.py を使用する場合、必要な変更は、キューに入れられたリソースをプロビジョニングするときに autocheckpoint-enabled フラグを設定することだけです。multihost_job.py を使用する場合に必要となる変更は、ジョブの起動時に ENABLE_AUTOCHECKPOINT=true コマンドライン フラグを指定することだけです。

クイックスタート: シングル スライス上の Pax での Autocheckpoint

このセクションでは、単一のスライスで Pax に対して Autocheckpoint を設定して使用する方法の例を示します。適切な設定を行います。

  • チェックポイントは、メンテナンス イベントが発生すると保存されます。
  • Cloud TPU は、チェックポイントの保存後、影響を受ける TPU VM のメンテナンスを実行します。
  • Cloud TPU でメンテナンスが完了すると、通常どおり TPU VM を使用できます。
  1. TPU VM またはキューに格納されたリソースを作成する場合は、autocheckpoint-enabled フラグを使用します。

    次に例を示します。

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Pax を単一スライスにインストールする

    Autocheckpoint 機能は、Pax バージョン 1.1.0 以降で動作します。TPU VM で、jax[tpu] と最新の paxml をインストールします。

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. 適切な構成でトレーニングを開始する

    次の例は、Autocheckpoint によってトリガーされたチェックポイントを Google Cloud Storage バケットに保存するように LmCloudSpmd2B モデルを構成する方法を示しています。

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    コマンドに渡される 2 つのフラグに注意してください。

    • jax_fully_async_checkpoint: このフラグを有効にすると、orbax.checkpoint.AsyncCheckpointer が使用されます。AsyncCheckpointer クラスは、トレーニング スクリプトが SIGTERM シグナルを受信すると、チェックポイントを自動的に保存します。
    • exit_after_ondemand_checkpoint: このフラグを有効にすると、TPU プロセスは Autocheckpoint が正常に保存された後に終了し、メンテナンスが直ちに実行されます。このフラグを使用しない場合、チェックポイントが保存された後、トレーニングは続行されます。Cloud TPU は、必要なメンテナンスが実行される前にタイムアウトが発生するまで待機します(5 分)。

クイックスタート: マルチスライス上の Pax での Autocheckpoint

Autocheckpoint は、単一のスライスだけでなく、マルチスライスにも機能します。このセクションでは、マルチスライスで Autocheckpoint を使用するために必要な手順について詳しく説明します。

  1. キューに登録されたリソースの作成時に Autocheckpoint を指定します。

    マルチスライス環境をプロビジョニングできるのは、キューに登録されたリソース リクエストのみです。シングルスライスの場合と同様に、呼び出しで autocheckpoint-enabled フラグを使用してキュー内にリソースを作成します。

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    使用可能なすべてのオプションの詳細については、マルチスライス ユーザーガイドをご覧ください。キューに登録されたリソース リクエストが作成され、ACTIVE 状態になったら、次の手順に従って、Autocheckpoint を使用して Pax を実行します。

  2. マルチスライス環境内のすべての VM に Pax をインストールします。

    TPU VM では、マルチスライス環境のすべての TPU VM に jax[tpu] と最新の paxml をインストールします。

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. 適切な構成でトレーニングを開始する

    この例では、マルチスライス環境でトレーニングする際に、Autocheckpoint 用のモデル LmCloudSpmd2B を構成する方法を示しています。トレーニング スクリプトを実行する前に、以下に示すように DCN_MESH_SHAPE を [2, 1, 1] に設定します。

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    トレーニングを開始するときは、シングルスライスの場合で説明したコマンドライン フラグに加えて、次の 3 つが必要になります。

    • num_hosts: ホストの合計数。この場合は 2 です。
    • host_index: トレーニングを開始するホストのインデックス。0 から N-1 の範囲で変化します。ここで、N はホストの総数です。
    • server_addr: 未使用のポート(たとえば、8476)を含むノード 0 のワーカー 0 の IP アドレス。この情報を見つけるには、ノード 0 のワーカー 0 で hostname -i を使用します。

Orbax での Autocheckpoint

Autocheckpoint 機能は、MaxText または Pax に限定されません。SIGTERM シグナルをキャプチャしてチェックポインティング プロセスを開始できるフレームワークは、Autocheckpoint によって提供されるインフラストラクチャで動作します。Orbax(JAX ユーザー用の一般的なユーティリティ ライブラリを提供する名前空間)は、これらの機能を提供します。

Orbax のドキュメントで説明されているように、orbax.checkpoint.CheckpointManager のユーザーはこれらの機能がデフォルトで有効になっています。各ステップの後に呼び出される save メソッドは、メンテナンス イベントが差し迫っているかどうかを自動的に確認します。その場合は、ステップ番号が save_interval_steps の倍数でない場合でもチェックポイントを保存します。また、GitHub のドキュメントでは、ユーザーコードを変更して、Autocheckpoint を保存した後にトレーニングを終了する方法も示しています。