TPU Pod スライスでの JAX コードの実行
JAX コードを単一の TPU ボードで実行したら、TPU Pod スライスで実行してコードをスケールアップできます。TPU Pod スライスは、専用の高速ネットワーク接続で互いに接続された複数の TPU ボードです。このドキュメントでは、TPU Pod スライスでの JAX コードの実行についての概要を説明します。詳しくは、マルチホスト環境とマルチプロセス環境での JAX の使用をご覧ください。
データ ストレージにマウントされた NFS を使用する場合は、Pod スライス内のすべての TPU VM に OS Login を設定する必要があります。詳細については、データ ストレージに NFS を使用するをご覧ください。TPU Pod スライスの作成
このドキュメントのコマンドを実行する前に、アカウントと Cloud TPU プロジェクトを設定するの手順に従ってください。 ローカルマシンで次のコマンドを実行します。
gcloud
コマンドを使用して TPU Pod スライスを作成します。たとえば、v4-32 Pod スライスを作成するには、次のコマンドを使用します。
$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-32 \ --version=tpu-ubuntu2204-base
Pod スライスに JAX をインストールする
TPU Pod スライスを作成したら、TPU Pod スライスのすべてのホストに JAX をインストールする必要があります。--worker=all
オプションを使用すると、1 つのコマンドですべてのホストに JAX をインストールできます。
gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b --worker=all --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Pod スライスで JAX コードを実行する
TPU Pod スライスで JAX コードを実行するには、TPU Pod スライスの各ホストでコードを実行する必要があります。jax.device_count()
呼び出しは、Pod スライスの各ホストで呼び出されるまで応答しなくなります。次の例は、TPU Pod スライスで単純な JAX 計算を実行する方法を示しています。
コードの準備
scp
コマンドでは、gcloud
バージョン 344.0.0 以降が必要です。gcloud --version
を使用して gcloud
のバージョンを確認し、必要に応じて gcloud components upgrade
を実行します。
次のコードを使用して example.py
という名前のファイルを作成します。
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
example.py
を Pod スライス内のすべての TPU ワーカー VM にコピーする
$ gcloud compute tpus tpu-vm scp example.py tpu-name: \ --worker=all \ --zone=us-central2-b
以前に scp
コマンドを使用したことがない場合、次のようなエラーが表示されることがあります。
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try again.
このエラーを解決するには、エラー メッセージに示されている ssh-add
コマンドを実行してから、コマンドを再実行します。
Pod スライスでコードを実行する
すべての VM で example.py
プログラムを起動します。
$ gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b \ --worker=all \ --command="python3 example.py"
出力(v4-32 Pod スライスで生成)
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
クリーンアップ
終了した後は、gcloud
コマンドを使用して TPU VM リソースを解放できます。
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central2-b