TPU Pod スライスでの JAX コードの実行

JAX コードを単一の TPU ボードで実行したら、TPU Pod スライスで実行してコードをスケールアップできます。TPU Pod スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU Pod スライスでの JAX コードの実行についての概要を説明します。詳しくは、マルチホスト環境とマルチプロセス環境での JAX の使用をご覧ください。

データ ストレージにマウントされた NFS を使用する場合は、Pod スライス内のすべての TPU VM に OS Login を設定する必要があります。詳細については、データ ストレージに NFS を使用するをご覧ください。

TPU Pod スライスの作成

このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの手順に従ってください。 ローカルマシンで次のコマンドを実行します。

gcloud コマンドを使用して TPU Pod スライスを作成します。たとえば、v4-32 Pod スライスを作成するには、次のコマンドを使用します。

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

Pod スライスに JAX をインストールする

TPU Pod スライスを作成したら、TPU Pod スライスのすべてのホストに JAX をインストールする必要があります。--worker=all オプションを使用すると、1 つのコマンドですべてのホストに JAX をインストールできます。

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Pod スライスで JAX コードを実行する

TPU Pod スライスで JAX コードを実行するには、TPU Pod スライスの各ホストでコードを実行する必要があります。jax.device_count() 呼び出しは、Pod スライスの各ホストで呼び出されるまで応答しなくなります。次の例は、TPU Pod スライスで単純な JAX 計算を実行する方法を示しています。

コードの準備

scp コマンドでは、gcloud バージョン 344.0.0 以降が必要です。gcloud --version を使用して gcloud のバージョンを確認し、必要に応じて gcloud components upgrade を実行します。

次のコードを使用して example.py という名前のファイルを作成します。

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py を Pod スライス内のすべての TPU ワーカー VM にコピーする

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

以前に scp コマンドを使用したことがない場合、次のようなエラーが表示されることがあります。

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

このエラーを解決するには、エラー メッセージに示されている ssh-add コマンドを実行してから、コマンドを再実行します。

Pod スライスでコードを実行する

すべての VM で example.py プログラムを起動します。

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

出力(v4-32 Pod スライスで生成)

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

クリーンアップ

終了した後は、gcloud コマンドを使用して TPU VM リソースを解放できます。

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b