Cloud TPU マルチスライスの概要
Cloud TPU マルチスライスは、単純なデータ並列処理で、単一の Pod 内、または複数の POD 内のスライスで、トレーニング ジョブが TPU スライスを使用することができるようにする、フルスタック パフォーマンス スケーリング テクノロジーです。TPU v4 チップでは、トレーニング ジョブは 1 回の実行で 4, 096 個を超えるチップを使用できます。4, 096 チップ未満を必要とするトレーニング ジョブの場合、単一スライスが最高のパフォーマンスを発揮します。ただし、複数の小さなスライスがより簡単に利用できるため、マルチスライスを小さなスライスで使用する場合、起動時間が短縮されます。
マルチスライス構成にデプロイすると、各スライス内の TPU チップがチップ間相互接続(ICI)を介して通信します。異なるスライス内の TPU チップは、CPU(ホスト)にデータを転送することで通信し、CPU はデータセンター ネットワーク(DCN)を介してデータを転送します。
スライス間 DCN 通信を実装するためにデベロッパーがコードを記述することはありません。XLA コンパイラが、そのコードを生成し、最大限のパフォーマンスが発揮できるようにコンピューティングと通信をオーバーラップします。
コンセプト
- アクセラレータ タイプ
- マルチスライスを含む各 TPU スライスの形状。マルチスライス リクエスト内の各スライスは、同じアクセラレータ タイプです。アクセラレータ タイプは、TPU タイプ(v4 または v5e)と TensorCore の数で構成されます。たとえば、
v4-128
は 128 個の TensorCore を使用する TPU v4 を指定します。 - 自動修復
- スライスにメンテナンス イベント、プリエンプション、またはハードウェアの障害が発生すると、Cloud TPU が新しいスライスを作成します。新しいスライスを作成するのに十分なリソースがない場合はまれですが、ハードウェアが使用可能になるまで作成は完了しません。新しいスライスを作成すると、マルチスライス環境内の他のすべてのスライスが再起動され、トレーニングを続行できます。適切に構成された起動スクリプトを使用すると、ユーザーの介入なしに、トレーニング スクリプトが自動的に再起動され、最新のチェックポイントから読み込みと再開が行われます。
- データセット
- モデルがトレーニングまたは推論に使用するデータ。
- データセンター ネットワーキング(DCN)
- マルチスライス構成で TPU スライスを接続する、高レイテンシ、低スループット(ICI と比較して)ネットワーク。
- ギャング スケジューリング
- すべての TPU スライスを同時にプロビジョニングすると、すべてのスライスが正常にプロビジョニングされるか、まったくプロビジョニングされないことを保証します。
- ホスト
- ホストは、VM を実行する物理コンピュータです。ホストで一度に実行できる VM は最大 4 つです。各 VM には専用の TPU があります。
- 推論
- 事前トレーニング済みの ML モデルをホストに読み込み、データの予測を行います。
- インターチップ相互接続(ICI)
- TPU Pod 内で TPU を接続する高速かつ低レイテンシの内部リンク。
- マルチスライス
- DCN を介して通信できる 2 つ以上の TPU チップスライス。
- ノード
- マルチスライスのコンテキストでは、ノードは単一の TPU スライスを指します。 マルチスライスの各 TPU スライスにはノード ID が割り当てられます。
- ポッド
- 専用の ICI ネットワーク インターフェースで接続された TPU チップのコレクション。Pod を使用すると、処理負荷を複数の TPU に分散できます。
- キューに登録されたリソース(QR)
- シングルスライスまたはマルチスライス TPU 環境に対するリクエストをキューに追加して管理するために使用される TPU リソースの表現。
- 起動スクリプト
- VM が起動または再起動されるたびに実行される標準の Compute Engine 起動スクリプト。マルチスライスの場合、QR 作成リクエストで指定します。Cloud TPU 起動スクリプトの詳細については、TPU リソースの管理をご覧ください。
- TPU スライス
- TPU チップで構成される TPU Pod の論理項。スライス内のすべてのチップは、ICI ネットワークを使用して相互に通信します。
- TPU VM
- 基盤となる TPU にアクセスできる Linux を実行している仮想マシン。 v4 TPU の場合、各 TPU VM は 4 つのチップに直接アクセスできます。TPU VM は、ワーカーとも呼ばれます。
- Tensor
- 機械学習モデルの多次元データを表すために使用されるデータ構造。
- Tensor processing unit (TPU)
- Google が社内で開発した ML アクセラレーション チップ。行列乗算などの主要な ML タスクに対して、高速で電力効率の高いコンピューティングを提供するように設計されています。
- Cloud TPU の容量のタイプ
TPU は、さまざまなタイプの容量から作成できます(TPU の料金の仕組みの使用オプションを参照)。
- 予約: 予約を使用するには、Google との予約契約が必要です。リソースを作成する際は
--reserved
フラグを使用します。 - Spot: Spot VM を使用するプリエンプティブルの割り当てを対象にします。優先度の高いジョブに対するリクエストのスペースを確保するために、リソースがプリエンプトされる場合があります。リソースを作成する際は
--spot
フラグを使用します。 - オンデマンド: 予約を必要とせずプリエンプトされない、オンデマンド割り当てを対象にします。TPU リクエストは、Cloud TPU が提供するオンデマンド割り当てキューに追加されます。リソースの可用性は保証されません。デフォルトで選択されています。フラグは必要ありません。
- 予約: 予約を使用するには、Google との予約契約が必要です。リソースを作成する際は
使ってみる
TPU を初めて使用する場合は、まず Google Cloud CLI をインストールし、Cloud TPU 環境を設定します。マルチスライスを使用するには、TPU リソースをキューに格納されたリソースとして管理する必要があります。
既存の TPU v4 ユーザーで、予約がある場合は、予約を新しい予約システムに移行する必要があります。詳しくは、Google Cloud アカウント担当者にお問い合わせください。
入門例
このチュートリアルでは、MaxText GitHub リポジトリのコードを使用します。MaxText は、Python と Jax で記述された、高性能で任意に拡張可能なオープンソースの十分にテストされた基本 LLM です。MaxText は、Cloud TPU で効率的にトレーニングするように設計されています。
shardings.py
のコードは、さまざまな並列処理オプションをテストする際に役立つように設計されています。たとえば、データ並列処理、完全にシャーディングされたデータ並列処理(FSDP)、テンソル並列処理などです。コードは、単一スライスからマルチスライス環境にスケーリングされます。
ICI 並列処理
ICI は、単一スライスの TPU を接続する高速相互接続を指します。ICI シャーディングは、スライス内のシャーディングに対応しています。shardings.py
には、次の 3 つの ICI 並列処理パラメータがあります。
ici_data_parallelism
ici_fsdp_parallelism
ici_tensor_parallelism
これらのパラメータに指定する値によって、各並列化メソッドのシャードの数が決まります。
これらの入力は、ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism
がスライス内のチップの数と等しくなるように制約する必要があります。
次の表に、v4 ~ 8 で使用可能な 4 つのチップの ICI 並列処理のユーザー入力の例を示します。
ici_data_parallelism | ici_fsdp_parallelism | ici_tensor_parallelism | |
4 ウェイ FSDP | 1 | 4 | 1 |
4 方向 Tensor 並列処理 | 1 | 1 | 4 |
双方向 FSDP + 双方向 TensorFlow 並列処理 | 1 | 2 | 2 |
ICI ネットワークは十分高速なため、ほぼ常にデータ並列処理よりも FSDP が優先されることから、ほとんどの場合 ici_data_parallelism
を 1 のままにする必要があります。
この例では、JAX を使用して Cloud TPU VM で計算を実行するなど、単一の TPU スライスでコードを実行することに精通していることを前提としています。この例は、単一のスライスで shardings.py
を実行する方法を示しています。
環境を設定します。
$ gcloud auth login $ gcloud config set project your-project-id $ gcloud config set compute/zone your-zone
gcloud
の SSH 認証鍵を作成します。パスワードは空白のままにすることをおすすめします(次のコマンドの実行後に 2 回 Enter を押します)。google_compute_engine
ファイルがすでに存在しているというメッセージが表示された場合は、既存のバージョンを置き換えます。$ ssh-keygen -f ~/.ssh/google_compute_engine
TPU をプロビジョニングします。
gcloud
$ gcloud compute tpus queued-resources \ create YOUR_QR_ID \ --accelerator-type your-accelerator-type \ --runtime-version tpu-ubuntu2204-base \ --node-id qr-id \ [--reserved |--spot]
コマンドフラグの説明
YOUR_QR_ID
- QR リクエストを識別するユーザー定義文字列。
accelerator-type
- アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU のバージョンごとにサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
runtime-version
- Cloud TPU ソフトウェアのバージョン。
node-id
- QR リクエストに応答して作成される TPU リソースの ID。
reserved
- スライスの作成時に予約を使用します。
spot
- スライスの作成時に Spot VM を使用します。
Google Cloud CLI では、タグなどの QR コードの作成オプションはサポートされていません。詳細については、QR コードを作成するをご覧ください。
Console
Google Cloud コンソールの [TPU] ページに移動します。
[TPU を作成] をクリックします。
[名前] フィールドに、TPU の名前を入力します。
[ゾーン] ボックスで、TPU を作成するゾーンを選択します。
[TPU タイプ] ボックスで、アクセラレータ タイプを選択します。アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
[TPU ソフトウェア バージョン] ボックスで、ソフトウェア バージョンを選択します。Cloud TPU VM の作成時には、インストールされる TPU ランタイム バージョンが TPU ソフトウェア バージョンによって指定されます。詳細については、TPU VM イメージをご覧ください。
[キューイングを有効にする] トグルをクリックします。
[キューに登録されたリソース名] フィールドに、キューに登録されたリソース リクエストの名前を入力します。
[作成] をクリックして、キューに格納されたリソース リクエストを作成します。
キューに格納されたリソースが
ACTIVE
状態になるまで待ちます。これは、ワーカーノードがREADY
状態であることを意味します。キューに入れられたリソースのプロビジョニングが開始されると、キューに入れられたリソースのサイズによっては、完了するまでに 1 ~ 5 分かかることがあります。キューに登録されているリソース リクエストのステータスは、gcloud CLI または Google Cloud コンソールを使用して確認できます。gcloud
$ gcloud compute tpus queued-resources \ list --filter=YOUR_QR_ID
Console
Google Cloud コンソールの [TPU] ページに移動します。
[キューに登録されているリソース] タブをクリックします。
キューに登録されているリソース リクエストの名前をクリックします。
v4-8 スライスには単一の TPU VM があります。SSH を使用して TPU VM に接続します。
$ gcloud compute tpus tpu-vm ssh YOUR_QR_ID
MaxText(
shardings.py
を含む)のクローンを TPU VM に作成します。MaxText リポジトリ ディレクトリでセットアップ スクリプトを実行して、TPU スライスに JAX などの依存関係をインストールします。このスクリプトのセットアップには数分かかります。
$ bash setup.sh
次のコマンドを実行して、TPU スライスで
shardings.py
を実行します。$ python3 pedagogical_examples/shardings.py \ --ici_fsdp_parallelism 4 \ --batch_size 131072 \ --embedding_dimension 2048
結果はログで確認できます。TPU は約 260 TFLOP/秒、または 90%以上の FLOP 使用率を達成する必要があります。この場合、TPU の高帯域幅メモリ(HBM)に収まるほぼ最大のバッチを選択しています。
ICI で他のシャーディング戦略を試すこともできます。たとえば、次の組み合わせを試すことができます。
$ python3 pedagogical_examples/shardings.py \ --ici_tensor_parallelism 4 \ --batch_size 131072 \ --embedding_dimension 2048
完了したら、キューに格納されたリソースと TPU スライスを削除します。これらのクリーンアップ手順は、スライスを設定した環境から実行する必要があります(まず
exit
を実行して SSH セッションを終了します)。削除が完了するまでに 2 ~ 5 分かかります。gcloud CLI を使用している場合は、オプションの--async
フラグを使用して、このコマンドをバックグラウンドで実行できます。gcloud
$ gcloud compute tpus queued-resources delete YOUR_QR_ID --force (--async)
Console
Google Cloud コンソールの [TPU] ページに移動します。
[キューに登録されているリソース] タブをクリックします。
キューに登録されているリソース リクエストの横にあるチェックボックスをオンにします。
[
削除] をクリックします。
DCN 並列処理を使用したマルチスライス シャーディング
shardings.py
スクリプトは、データ並列処理の各タイプのシャード数に対応する、DCN 並列処理を指定する 3 つのパラメータを受け入れます。
- dcn_data_parallelism
- dcn_fsdp_parallelism
- dcn_tensor_parallelism
これらのパラメータの値は、dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism
がスライス数と等しくなるように制約する必要があります。
2 つのスライスの場合は、--dcn_data_parallelism = 2
を使用します。
dcn_data_parallelism | dcn_fsdp_parallelism | dcn_tensor_parallelism | スライス数 | |
双方向データ並列処理 | 2 | 1 | 1 | 2 |
DCN はこのようなシャーディングには適していないため、dcn_tensor_parallelism
は常に 1
に設定する必要があります。v4 チップの一般的な LLM ワークロードでは、dcn_fsdp_parallelism
も 1
に設定する必要があるため、dcn_data_parallelism
をスライス数に設定する必要がありますが、これはアプリケーションによって異なります。
スライスの数を増やすと(スライスのサイズとスライスごとのバッチ数が一定に保たれていると想定)、データの並列処理の量が増えます。
マルチスライス環境で shardings.py
を実行する
マルチスライス環境で shardings.py
を実行するには、multihost_runner.py
を使用するか、各 TPU VM で shardings.py
を実行します。ここでは multihost_runner.py
を使用します。
次の手順は、MaxText リポジトリからのはじめに: 複数のスライスでの簡単なテストの手順と似ています。ただしここでは、train.py
のより複雑な LLM の代わりに shardings.py
を実行します。
multihost_runner.py
ツールは、同じ TPU を繰り返し再利用して、簡単なテスト用に最適化されています。multihost_runner.py
スクリプトは長時間 SSH 接続に依存するため、長時間実行ジョブにはおすすめしません。
長時間のジョブ(数時間または数日間など)を実行する場合は、multihost_job.py
を使用することをおすすめします。
このチュートリアルでは、multihost_runner.py
スクリプトを実行するマシンを示すためにランナーという用語を使用します。スライスを構成する TPU VM を示すためにワーカーという用語を使用します。multihost_runner.py
は、ローカルマシン、またはスライスと同じプロジェクト内の任意の Compute Engine VM 上で実行できます。 ワーカーでの multihost_runner.py
の実行はサポートされていません。
multihost_runner.py
は、SSH を使用して TPU ワーカーに自動的に接続します。
この例では、2 つの v4-16 スライス(合計 4 つの VM と 16 個の TPU チップ)で shardings.py
を実行します。サンプルを変更し、より多くの TPU で実行できるようにできます。
環境を設定する
ランナーマシンに MaxText のクローンを作成します。
リポジトリ ディレクトリに移動します。
gcloud
の SSH 認証鍵を作成します。空白のパスワードを残すことをおすすめします(次のコマンドの実行後に 2 回 Enter を押します)。google_compute_engine
ファイルがすでに存在しているというメッセージが表示された場合は、既存のバージョンを保持しないを選択します。$ ssh-keygen -f ~/.ssh/google_compute_engine
環境変数を追加して、TPU スライス数を
2
に設定します。$ export SLICE_COUNT=2
queued-resources create
コマンドまたは Google Cloud コンソールを使用して、マルチスライス環境を作成します。gcloud
次のコマンドは、v4 マルチスライス TPU を作成する方法を示しています。v5e を使用するには、v5e
accelerator-type
(v5litepod-16
など)と v5eruntime-version
(v2-alpha-tpuv5-lite
)を指定します。$ gcloud compute tpus queued-resources \ create YOUR_QR_ID \ --accelerator-type=your-accelerator-type \ --runtime-version=tpu-vm-runtime-version \ --node-count=node-count \ --node-prefix=YOUR_QR_ID \ [--reserved|--spot]
コマンドフラグの説明
YOUR_QR_ID
- QR リクエストを識別するユーザー定義文字列。
accelerator-type
- アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされます。TPU のバージョンごとにサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
runtime-version
- Cloud TPU ソフトウェアのバージョン。
node-count
- 作成するスライス数。
node-prefix
- 各スライスの名前の生成に使用される接頭辞。各スライスの接頭辞に番号が追加されます。たとえば、
node-prefix
をmySlice
に設定すると、スライスの名前はmySlice-0
、mySlice-1
となり、スライスごとに数字が続きます。 reserved
- スライスの作成時に予約を使用します。
spot
- スライスの作成時に Spot VM を使用します。
Console
Google Cloud コンソールの [TPU] ページに移動します。
[TPU を作成] をクリックします。
[名前] フィールドに、TPU の名前を入力します。
[ゾーン] ボックスで、TPU を作成するゾーンを選択します。
[TPU タイプ] ボックスで、アクセラレータ タイプを選択します。アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされています。TPU の各バージョンでサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
[TPU ソフトウェア バージョン] ボックスで、ソフトウェア バージョンを選択します。Cloud TPU VM の作成時には、インストールされる TPU ランタイム バージョンが TPU ソフトウェア バージョンによって指定されます。詳細については、TPU VM イメージをご覧ください。
[キューイングを有効にする] トグルをクリックします。
[キューに登録されたリソース名] フィールドに、キューに登録されたリソース リクエストの名前を入力します。
[マルチスライス TPU にする] チェックボックスをオンにします。
[スライス数] フィールドに、作成するスライス数を入力します。
[作成] をクリックして、キューに格納されたリソース リクエストを作成します。
キューに登録されたリソースのプロビジョニングが開始されると、キューに登録されたリソースのサイズによっては、完了までに最大 5 分かかることがあります。キューに格納されたリソースが
ACTIVE
状態になるまで待ちます。キューに登録されているリソース リクエストのステータスは、gcloud CLI または Google Cloud コンソールを使用して確認できます。gcloud
$ gcloud compute tpus queued-resources list \ --filter=YOUR_QR_ID
次のような出力が表示されます。
NAME ZONE NODE_COUNT ACCELERATOR_TYPE STATE ... que-res-id us-central2-b 4 v4-16 ACTIVE ...
Console
Google Cloud コンソールの [TPU] ページに移動します。
[キューに登録されているリソース] タブをクリックします。
キューに登録されているリソース リクエストの名前をクリックします。
QR のステータスが 15 分以上
WAITING_FOR_RESOURCES
またはPROVISIONING
の状態になっている場合は、Google Cloud アカウント担当者にお問い合わせください。依存関係をインストールします。
$ python3 multihost_runner.py \ --TPU_PREFIX=YOUR_QR_ID \ --COMMAND="bash setup.sh"
multihost_runner.py
を使用して各ワーカーでshardings.py
を実行します。$ python3 multihost_runner.py \ --TPU_PREFIX=YOUR_QR_ID \ --COMMAND="python3 pedagogical_examples/shardings.py \ --dcn_data_parallelism $SLICE_COUNT \ --ici_fsdp_parallelism 8 \ --batch_size 131072 \ --embedding_dimension 2048"
ログファイルには、1 秒あたり約 230 TFLOP のパフォーマンスが表示されます。
完了したら、TPU とキューに格納されたリソースをクリーンアップします。削除が完了するまでに 2 ~ 5 分かかります。gcloud CLI を使用している場合は、オプションの
--async
フラグを使用して、このコマンドをバックグラウンドで実行できます。
ワークロードをマルチスライスにスケーリングする
マルチスライス環境でモデルを実行する前に、次のコードを変更します。
- メッシュの作成時には、jax.experimental.mesh_utils.create_device_mesh ではなく jax.experimental.mesh_utils.create_hybrid_device_mesh を使用します。
マルチスライスに移行する際に必要なコード変更は、これらだけです。高いパフォーマンスを実現するには、DCN をデータ並列に、完全にシャーディングされたデータ並列軸またはパイプライン並列軸にマッピングする必要があります。パフォーマンスに関する考慮事項とシャーディング戦略の詳細については、最大パフォーマンスのためのマルチスライスを使用したシャーディングをご覧ください。
コードがすべてのデバイスにアクセスできることを確認するには、len(jax.devices())
がマルチスライス環境のチップ数と等しいことを表明します。たとえば、v4-16
の 4 つのスライスを使用している場合、スライスあたり 8 個のチップ * 4 つのスライスがあるため、len(jax.devices())
は 32 を返します。
マルチスライス環境のスライスサイズの選択
速度を線形的に向上させるには、既存のスライドと同じサイズの新しいスライドを追加します。たとえば、v4-512
スライスを使用する場合、マルチスライスでは、2 番目の v4-512
スライスを追加してグローバル バッチサイズを 2 倍にすることで、約 2 倍のパフォーマンスを実現できます。詳細については、最大パフォーマンスのためのマルチスライスを使用したシャーディングをご覧ください。
複数のスライスでのジョブの実行
マルチスライス環境でカスタム ワークロードを実行するには、次の 3 つの方法があります。
- 試験運用版ランナー スクリプト
multihost_runner.py
を使用 - 本番環境ランナー スクリプト
multihost_job.py
を使用 - 手動アプローチを使用
試験運用版ランナー スクリプト
multihost_runner.py
スクリプトは、既存のマルチスライス環境にコードを配布し、各ホストでコマンドを実行して、ログをコピーし、各コマンドのエラー ステータスを追跡します。multihost_runner.py
スクリプトについては、MaxText の README をご覧ください。
multihost_runner.py
では永続的な SSH 接続が維持されるため、比較的小規模な、比較的短時間で実行されるテストにのみ適しています。multihost_runner.py チュートリアルの手順をワークロードとハードウェア構成に合わせて調整できます。
本番環境ランナー スクリプト
ハードウェアの障害やその他のプリエンプションに対する復元力が必要な本番環境ジョブの場合は、Create Queued Resource API と直接統合することをおすすめします。実用的な例として、Google は multihost_job.py
を提供しています。これにより、適切な起動スクリプトで Created Queued Resource API 呼び出しをトリガーし、プリエンプション時にトレーニングと再開を行います。multihost_job.py
スクリプトについては、MaxText の README をご覧ください。
multihost_job.py
は実行ごとにリソースをプロビジョニングする必要があるため、multihost_runner.py
のような速い反復サイクルは提供されません。
手動アプローチ
マルチスライス構成でカスタム ワークロードを実行するには、multihost_runner.py または multihost_job.py を使用するか、適宜変更することをおすすめします。ただし、QR コマンドを直接使用して環境をプロビジョニングして管理する場合は、マルチスライス環境を管理するをご覧ください。
マルチスライス環境を管理する
MaxText リポジトリで提供されるツールを使用せずに、QR を手動でプロビジョニングして管理するには、以下のセクションをご覧ください。
キューに格納されるリソースを作成する
gcloud
容量をプロビジョニングする前に、次の環境変数を設定します。
$ export YOUR_QR_ID=your-queued-resource-id $ export PROJECT=your-project-name $ export ZONE=us-central2-b $ export NETWORK_NAME=your-network-name $ export SUBNETWORK_NAME=your-subnetwork-name $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-16 $ export SLICE_COUNT=4 $ export STARTUP_SCRIPT="#!/bin/bash\n ..." $ gcloud config set project project-name $ gcloud config set compute/zone zone
変数の説明
入力 説明 YOUR_QR_ID キューに格納されたリソースのユーザー割り当て ID。 プロジェクト Google Cloud プロジェクト名 ZONE リソースを作成するゾーンを指定します。 NETWORK_NAME VPC ネットワークの名前。 SUBNETWORK_NAME VPC ネットワーク内のサブネットの名前 RUNTIME_VERSION Cloud TPU ソフトウェアのバージョン。 ACCELERATOR_TYPE v4-16 EXAMPLE_TAG_1, EXAMPLE_TAG_2 … ネットワーク ファイアウォールの有効なソースやターゲットを識別するために使用されるタグ SLICE_COUNT スライス数。最大 256 スライスに制限されています。 STARTUP_SCRIPT 起動スクリプトを指定すると、TPU スライスがプロビジョニングまたは再起動されたときにスクリプトが実行されます。 次のコマンドを使用して、キューに格納されたリソース リクエストを作成します。
$ gcloud compute tpus queued-resources \ create ${YOUR_QR_ID} \ --project your-project-id \ --zone your-zone \ --node-count ${SLICE_COUNT} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --network ${NETWORK_NAME} \ --subnetwork ${SUBNETWORK_NAME} \ --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}' [--reserved|--spot]
コマンドフラグの説明
YOUR_QR_ID
- キューに格納されたリソース リクエストを識別するユーザー定義文字列。
project
- キューに入れられたリソース リクエストを作成する Google Cloud プロジェクト。
zone
- キューに入れられたリソースを作成する Google Cloud ゾーン。
node-count
- 作成するスライス数。
accelerator-type
- アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされます。TPU のバージョンごとにサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
runtime-version
- Cloud TPU ソフトウェアのバージョン。
network
- TPU リソースをアタッチする VPC ネットワークの名前。
subnetwork
- TPU リソースをアタッチする VPC サブネットの名前。
reserved
- スライスの作成時に予約を使用します。
spot
- スライスの作成時に Spot VM を使用します。
--reserved
、--spot
、またはデフォルトのオンデマンド割り当てを選択する前に、それぞれの割り当てがあることを確認してください。割り当てタイプの詳細については、割り当てポリシーをご覧ください。
curl
queued-resource-req.json
という名前のファイルを作成して、次の JSON をコピーします。{ "guaranteed": { "reserved": true }, "tpu": { "node_spec": [ { "parent": "projects/your-project-number/locations/your-zone", "node": { "accelerator_type": "accelerator-type", "runtime_version": "tpu-vm-runtime-version", "network_config": { "network": "your-network-name", "subnetwork": "your-subnetwork-name", "enable_external_ips": true }, "tags" : ["example-tag-1"] "metadata": { "startup-script": "your-startup-script" } }, "multi_node_params": { "node_count": slice-count, "node_id_prefix": "your-queued-resource-id" } } ] } }
次の値を置き換えます。
- your-project-number - Google Cloud プロジェクト番号。
- your-zone - キューに登録するリソースを作成するゾーン
- accelerator-type - 単一スライスのバージョンとサイズ。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされます。
- tpu-vm-runtime-version - 使用する TPU VM ランタイム バージョン。
- your-network-name - 省略可。キューに入れられたリソースが接続されるネットワーク
- your-subnetwork-name - 省略可。キューに入れられたリソースが接続されるサブネットワーク
- example-tag-1 - 省略可、任意のタグ文字列
- your-startup-script - キューに入れられたリソースが割り振られるときに実行される起動スクリプト
- slice-count - マルチスライス環境内の TPU スライスの数
- YOUR_QR_ID - キューに格納されたリソースのユーザー指定 ID
詳細については、利用可能なすべてのオプションに関する、REST キューに格納されたリソース API のドキュメントをご覧ください。
Spot 容量を使用するには、次のように置き換えます。
"spot": {}
の"guaranteed": { "reserved": true }
行を削除して、デフォルトのオンデマンド容量を使用します。
JSON ペイロードを使用して、キューに格納されたリソース作成リクエストを送信します。
$ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" \ -H "Content-Type: application/json" \ -d @queuedresourcereq.json \ https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=YOUR_QR_ID
次の値を置き換えます。
- your-project-id - Google Cloud プロジェクト ID。
- your-zone - キューに登録するリソースを作成するゾーン
- YOUR_QR_ID - キューに格納されたリソースのユーザー指定 ID
レスポンスは次のようになります。
{ "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>", "metadata": { "@type": "type.googleapis.com/google.cloud.common.OperationMetadata", "createTime": "2023-11-01T00:17:05.742546311Z", "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>", "verb": "create", "cancelRequested": false, "apiVersion": "v2alpha1" }, "done": false }
name
属性の文字列値の末尾にある GUID 値を使用して、キューに登録されたリソース リクエストに関する情報を取得します。
Console
Google Cloud コンソールの [TPU] ページに移動します。
[TPU を作成] をクリックします。
[名前] フィールドに、TPU の名前を入力します。
[ゾーン] ボックスで、TPU を作成するゾーンを選択します。
[TPU タイプ] ボックスで、アクセラレータ タイプを選択します。アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。マルチスライスは、Cloud TPU v4 以降の TPU バージョンでのみサポートされています。TPU のバージョンごとにサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
[TPU ソフトウェア バージョン] ボックスで、ソフトウェア バージョンを選択します。Cloud TPU VM の作成時には、インストールされる TPU ランタイム バージョンが TPU ソフトウェア バージョンによって指定されます。詳細については、TPU VM イメージをご覧ください。
[キューイングを有効にする] トグルをクリックします。
[キューに登録されたリソース名] フィールドに、キューに登録されたリソース リクエストの名前を入力します。
[マルチスライス TPU にする] チェックボックスをオンにします。
[スライス数] フィールドに、作成するスライス数を入力します。
[作成] をクリックして、キューに格納されたリソース リクエストを作成します。
キューに格納されているリソースのステータスを取得する
gcloud
$ gcloud compute tpus queued-resources describe ${YOUR_QR_ID}
キューに格納されたリソースが ACTIVE
状態の場合、出力は次のようになります。
... state: state: ACTIVE ...
curl
$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/${YOUR_QR_ID}
キューに格納されたリソースが ACTIVE
状態の場合、出力は次のようになります。
{ "name": your-queued-res, "tpu": { "nodeSpec": [ { ... // node 1 }, { ... // node 2 }, ... ] }, ... "state": "ACTIVE" }
Console
Google Cloud コンソールの [TPU] ページに移動します。
[キューに登録されたリソース] タブをクリックします。
キューに登録されているリソース リクエストの名前をクリックします。
TPU がプロビジョニングされたら、TPU ページに移動して TPU を見つけ、対応するキューに格納されたリソース リクエストの名前をクリックすると、キューに格納されたリソース リクエストの詳細を確認することもできます。
まれに、キューに登録されているリソースが FAILED
状態である一方、一部のスライスが ACTIVE
になる場合があります。この場合は、作成されたリソースを削除してから、数分後にもう一度試すか、Google Cloud サポートにお問い合わせください。
SSH を使用して依存関係をインストールする
TPU Pod スライスで JAX コードを実行するでは、単一スライスで SSH を使用して TPU VM に接続する方法について説明します。SSH 経由でマルチスライス環境内のすべての TPU VM に接続し、依存関係をインストールするには、次の gcloud
コマンドを使用します。
$ gcloud compute tpus queued-resources ssh ${YOUR_QR_ID} \ --zone your-zone \ --node=all \ --worker=all \ --command="command-to-run" --batch-size=4
この gcloud
コマンドは、SSH を使用して、指定されたコマンドを QR 内のすべてのワーカーとノードに送信します。このコマンドは、4 つのグループにバッチ化され、同時に送信されます。現在のバッチの実行が完了すると、コマンドの次のバッチが送信されます。いずれかのコマンドでエラーが発生すると、処理が停止し、それ以降のバッチは送信されません。詳細については、キューに格納されたリソースの API リファレンスをご覧ください。使用しているスライスの数がローカル コンピュータのスレッド制限(バッチ処理の制限とも呼ばれる)を超えると、デッドロックが発生します。たとえば、ローカルマシンのバッチ処理の上限が 64 であるとします。64 を超えるスライス(100 など)でトレーニング スクリプトを実行しようとすると、SSH コマンドでスライスがバッチに分割されます。最初の 64 スライスのバッチでトレーニング スクリプトを実行し、スクリプトが完了するまで待ってから、残りの 36 スライスのバッチでスクリプトを実行します。ただし、残りの 36 スライスがスクリプトの実行を開始するまで、最初の 64 スライスのバッチは完了できず、デッドロックが発生します。
このシナリオを回避するには、--command
フラグで指定したスクリプト コマンドにアンパサンド(&
)を追加して、各 VM でトレーニング スクリプトをバックグラウンドで実行します。これを行う場合、スライスの最初のバッチでトレーニング スクリプトを開始すると、制御が直ちに SSH コマンドに戻ります。SSH コマンドは、残りの 36 スライスのバッチでトレーニング スクリプトの実行を開始できます。バックグラウンドでコマンドを実行する場合は、stdout
ストリームと stderr
ストリームを適切にパイプする必要があります。同じ QR 内で並列処理を増やすには、--node
パラメータを使用して特定のスライスを選択します。
ネットワーク設定
次の手順を実行して、TPU スライスが相互に通信できることを確認します。各スライスに JAX をインストールします。 詳細については、Cloud TPU Pod スライスで JAX コードを実行するをご覧ください。len(jax.devices())
がマルチスライス環境のチップ数と等しいことをアサートします。これを行うには、各スライスで次のコマンドを実行します。
$ python3 -c 'import jax; print(jax.devices())'
このコードを v4-16 の 4 つのスライスで実行する場合、スライスごとに 8 つのチップと 4 つのスライスがあり、合計 32 チップ(デバイス)が jax.devices()
によって返されます。
キューに格納されているリソースを一覧表示する
gcloud
キューに格納されたリソースの状態を確認するには、queued-resources list
コマンドを使用します。
$ gcloud compute tpus queued-resources list
出力は次のようになります。
NAME ZONE NODE_COUNT ACCELERATOR_TYPE STATE ... que-res-id us-central2-b 4 v4-16 ACTIVE ...
Console
Google Cloud コンソールの [TPU] ページに移動します。
[キューに登録されているリソース] タブをクリックします。
プロビジョニングされた環境でジョブを開始する
ワークロードを手動で実行するには、SSH 経由で各スライスのすべてのホストに接続し、すべてのホストで次のコマンドを実行します。
$ gcloud compute tpus tpu-vm ssh YOUR_QR_ID \ --zone=your-zone \ --worker=all \ --node=all \ --command="command-to-run"
QR のリセット
ResetQueuedResource
API を使用すると、ACTIVE
QR 内のすべての VM をリセットできます。VM をリセットすると、マシンのメモリが強制的に消去され、VM は初期状態にリセットされます。ローカルに保存されているデータはそのまま維持され、リセット後に起動スクリプトが呼び出されます。ResetQueuedResource
API は、すべての TPU を再起動する場合に便利です。たとえば、トレーニングが停止していて、すべての VM をリセットする方がデバッグよりも簡単な場合などです。
すべての VM の再起動は並行して行われ、ResetQueuedResource
オペレーションの完了には 1 ~ 2 分かかります。API を呼び出すには、次のコマンドを使用します。
$ gcloud compute tpus queued-resources reset YOUR_QR_ID
キューに格納されているリソースの削除
トレーニング セッションの終了時にリソースを解放するには、キューに格納されたリソースを削除します。削除が完了するまでに 2 ~ 5 分かかります。gcloud CLI を使用している場合は、オプションの --async
フラグを使用して、このコマンドをバックグラウンドで実行できます。
gcloud
$ gcloud compute tpus queued-resources \ delete YOUR_QR_ID --force (--async)
Console
Google Cloud コンソールの [TPU] ページに移動します。
[キューに登録されているリソース] タブをクリックします。
キューに登録されているリソース リクエストの横にあるチェックボックスをオンにします。
[
削除] をクリックします。
障害からの自動回復
停止した場合、マルチスライスでは、影響を受けたスライスを介入なしで修復し、その後、すべてのスライスをリセットできます。影響を受けたスライスが新しいスライスに置き換えられ、残りの正常なスライスがリセットされます。置き換えるスライスを割り当てるための容量がない場合、トレーニングは停止します。
中断後にトレーニングを自動的に再開するには、最後に保存されたチェックポイントをチェックし、読み込む起動スクリプトを指定する必要があります。起動スクリプトは、スライスが再割り当てされるたび、または VM がリセットされるたびに自動的に実行されます。create QR request API に送信する JSON ペイロードで起動スクリプトを指定します。
次の起動スクリプト(QR を作成するで使用)を使用すると、MaxText トレーニング中に、障害から自動的に回復し、Cloud Storage バケットに保存されているチェックポイントからトレーニングを再開できます。
{ "tpu": { "node_spec": [ { ... "metadata": { "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF" } ... } ] } }
これを試す前に、MaxText リポジトリのクローンを作成してください。
プロファイリングとデバッグ
プロファイリングは、シングルスライス環境とマルチスライス環境で同じです。詳細については、JAX プログラムのプロファイリングをご覧ください。
研修を最適化する
最大パフォーマンスのためのマルチスライスを使用したシャーディング
マルチスライス環境で最大のパフォーマンスを実現するには、複数のスライスにわたってシャーディングする方法を考慮する必要があります。通常、3 つの選択肢があります(データ並列処理、完全にシャーディングされたデータ並列処理、パイプライン並列処理)。モデル間ディメンション(テンソル並列処理とも呼ばれます)全体で活性化をシャーディングすることはおすすめしません。スライス間帯域幅が大きすぎるためです。 上記のどの戦略でも、過去にうまくいったスライス内で同じシャーディング戦略を使用できます。
最初は純粋なデータ並列処理から始めることをおすすめします。完全にシャーディングされたデータ並列処理を使用すると、メモリ使用量を削減できます。ただし、スライス間の通信で DCN ネットワークを使用するため、ワークロードが遅くなるという欠点があります。パイプラインの並列処理は、バッチサイズに基づいて必要な場合にのみ使用します(以下で分析)。
データ並列処理を使用する場面
純粋なデータ並列処理は、ワークロードが正常に実行されているが、複数のスライスにスケーリングしてパフォーマンスを向上させる必要がある場合に適しています。
複数のスライスで強力なスケーリングを行うには、DCN 全体で all-reduce を実行するために必要な時間が、バックワードパスを実行するために必要な時間よりも短くなければなりません。DCN はスライス間の通信に使用され、ワークロードのスループットの制限要因となります。
各 v4 TPU チップは、ピーク時に 275 x 1012 FLOPS/秒で動作します。
TPU ホストごとに 4 つのチップがあり、各ホストの最大ネットワーク帯域幅は 50 Gbps です。
つまり、算術強度は 4 × 275 × 10 12 FLOPS ÷ 50 Gbps = 22000 FLOPS / ビットです。
モデルでは、ステップごとに各パラメータに 32 ~ 64 ビットの DCN 帯域幅が使用されます。 2 つのスライスを使用する場合、モデルは 32 ビットの DCN 帯域幅を使用します。3 つ以上のスライスを使用する場合、コンパイラによって完全なシャッフル all-reduce オペレーションが実行され、ステップごとに各パラメータに対して最大 64 ビットの DCN 帯域幅が使用されます。各パラメータに必要な FLOPS の量はモデルによって異なります。具体的には、Transformer ベースの言語モデルの場合、フォワードパスとバックワードパスに必要な FLOPS の数は約 6 x B x P です。ここで、
- B はトークンのバッチサイズ
- P はパラメータの数
パラメータあたりの FLOPS の数は 6 * B
で、バックワードパス中のパラメータあたりの FLOPS の数は 4 * B
です。
複数のスライス間で強力なスケーリングを行うには、オペレーションの強度が TPU ハードウェアの算術強度を超えていることを確認します。動作の強度を計算するには、バックワードパス中のパラメータあたりの FLOPS 数を、1 ステップあたりの各パラメータごとのネットワーク帯域幅(ビット数)で割ります。
Operational Intensity = FLOPSbackwards_pass / DCN bandwidth
したがって、Transformer ベースの言語モデルの場合、2 つのスライスを使用する場合は次のようになります。
Operational intensity = 4 * B / 32
2 つを超えるスライスを使用している場合: Operational intensity = 4 * B/64
これにより、Transformer ベースの言語モデルの最小バッチサイズは 176,000 ~ 352,000 になります。DCN ネットワークは一時的にパケットをドロップできるため、Pod あたりのバッチサイズが 350, 000(2 つの Pod)以上で 700, 000(多数の Pod)の場合にのみ、データの並列処理をデプロイしてデータの並列処理を行うことをおすすめします。
他のモデルのアーキテクチャでは、(プロファイラを使用してタイミングを設定するか、FLOPS をカウントして)スライスあたりのバックワードパスのランタイムを推定する必要があります。次に、DCN を介した all-reduce の予想実行時間と比較して、データ並列処理が適切かどうかを正確に推定できます。
完全にシャーディングされたデータ並列処理(FSDP)を使用する場面
完全にシャーディングされたデータ並列処理(FSDP)は、データ並列処理(ノード間でのデータのシャーディング)とノード間の重みのシャーディングを組み合わせたものです。フォワードパスとバックワードパスのオペレーションごとに、すべてのスライスが all-gather され、各スライスに必要な重みが付けられます。all-reduce を使用して勾配を同期させる代わりに、勾配を生成するときに、reduce-scatter が行われます。このようにして、各スライスは、担当する重みの勾配のみを取得します。
データ並列処理と同様に、FSDP では、グローバル バッチサイズをスライス数に比例してスケーリングする必要があります。FSDP では、スライスの数を増やすとメモリ負荷が軽減されます。これは、スライスあたりの重みとオプティマイザーの状態の数が減少しても、ネットワーク トラフィックの増大と、遅延によるブロックの可能性が増大するためです。
実際には、スライスあたりのバッチ数を増やす場合、バックワードパス中の再実体化を最小限に抑えるためにより多くのアクティベーションを保存する場合、またはニューラル ネットワークのパラメータ数を増やす場合に、スライス間の FSDP が最適です。
FSDP の all-gather オペレーションと all-reduce オペレーションは DP のオペレーションと同様であるため、前のセクションで説明した方法と同じように、FSDP ワークロードが DCN のパフォーマンスによって制限されているかどうかを判断できます。
パイプライン並列処理を使用する場面
パイプラインの並列処理は、優先される最大バッチサイズよりも大きいグローバル バッチサイズを必要とする別の並列処理戦略で高いパフォーマンスを達成するときに重要になります。パイプラインの並列処理により、パイプラインを構成するスライスがバッチを「共有」できます。ただし、パイプラインの並列処理には次の 2 つの大きなデメリットがあります。
- チップがデータを待機しているため、アイドル状態になる「パイプライン バブル」が発生します。
- 効果的なバッチサイズ、算術強度を低減し、最終的には FLOP 使用率をモデル化するマイクロバッチ処理を必要とします。
パイプラインの並列処理は、他の並列処理戦略で必要とされるグローバル バッチサイズが大きくなりすぎる場合にのみ使用してください。パイプラインの並列処理を試す前に、高パフォーマンスの FSDP を達成するために必要なバッチサイズでサンプルあたりの収束が低下するかどうかを実験的に確認することをおすすめします。FSDP はモデルの FLOPS 使用率を高める傾向がありますが、バッチサイズの増加に伴いサンプルあたりの収束が遅くなる場合は、パイプラインの並列処理の方が適している場合があります。ほとんどのワークロードは、パイプラインの並列処理のメリットを享受できないほど大きなバッチサイズを許容できますが、ワークロードによって異なる場合があります。
パイプラインの並列処理が必要な場合は、データ並列処理または FSDP と組み合わせることをおすすめします。これにより、DCN レイテンシがスループットの要因にならない程度まで、パイプラインごとのバッチサイズを増やしながら、パイプラインの深さを最小限に抑えることができます。具体的には、N 個のスライスがある場合は、深さ 2 のパイプラインとデータ並列処理の N/2 のレプリカを検討し、次に深さ 4 のパイプラインとデータ並列処理の N/4 のレプリカを検討するなど、DCN 集合をバックワードパス内の算術の背後に隠すことができるほど、パイプラインあたりのバッチが十分大きくなるまで続けます。これにより、パイプラインの並列処理によって生じる速度低下を最小限に抑えながら、グローバル バッチサイズの上限を超えてスケーリングできます。
マルチスライスに関するベスト プラクティス
データ読み込み
トレーニング中は、データセットからバッチを繰り返し読み込み、モデルにフィードします。TPU の作業不足を回避するには、バッチをホスト間でシャーディングする効率的な非同期データローダーが必要です。MaxText の現在のデータローダーでは、各ホストがサンプルの同等のサブセットを読み込みます。このソリューションはテキストには適していますが、モデル内での再シャーディングが必要です。また、MaxText では、データ イテレータがプリエンプションの前後に同じデータを読み込むことができる決定的スナップショットはまだ提供されていません。
チェックポインティング
Orbax チェックポインティング ライブラリは、JAX PyTrees をローカル ストレージまたは Google Cloud ストレージにチェックポインティングするためのプリミティブを提供します。checkpointing.py
では、MaxText への同期チェックポインティングを備えたリファレンスのインテグレーションが提供されます。
サポートされている構成
図形
すべてのスライスが同じ形状(たとえば同じ AcceleratorType
)である必要があります。
異種スライス形状はサポートされていません。
オーケストレーション
オーケストレーションは GKE でサポートされています。詳細については、GKE の TPU をご覧ください。
フレームワーク
マルチスライスは、JAX と PyTorch のワークロードのみをサポートします。
並列処理
データ並列処理でマルチスライスをテストすることをおすすめします。マルチスライスを使用したパイプラインの並列処理の実装の詳細については、Google Cloud アカウント担当者にお問い合わせください。
サポートとフィードバック
フィードバックをぜひお寄せください。フィードバックを共有したり、サポートをリクエストしたりするには、Cloud TPU サポートまたはフィードバック フォームを使用してご連絡ください。