JAX-Code auf TPU-Pod-Slices ausführen

Nachdem Sie den JAX-Code auf einem einzelnen TPU-Board ausgeführt haben, können Sie den Code skalieren, indem Sie ihn auf einem TPU-Pod-Slice ausführen. TPU Pod-Slices sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument bietet eine Einführung zum Ausführen von JAX-Code auf TPU Pod-Slices. Ausführlichere Informationen finden Sie unter JAX in Umgebungen mit mehreren Hosts und mehreren Prozessen verwenden.

Wenn Sie die bereitgestellte NFS-Datenspeicher verwenden möchten, müssen Sie OS Login für alle TPU-VMs im Pod-Slice festlegen. Weitere Informationen finden Sie unter NFS als Datenspeicher verwenden.

TPU Pod-Slice erstellen

Bevor Sie die Befehle in diesem Dokument ausführen, sollten Sie sich an die Anweisungen unter Konto und Cloud TPU-Projekt einrichten gehalten haben. Führen Sie die folgenden Befehle auf Ihrem lokalen Computer aus.

Erstellen Sie mit dem Befehl gcloud ein TPU Pod-Slice. Verwenden Sie beispielsweise den folgenden Befehl, um ein v2-32-Pod-Slice zu erstellen:

$ gcloud compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version tpu-vm-base

JAX auf dem Pod-Slice installieren

Nachdem Sie das TPU Pod-Slice erstellt haben, müssen Sie JAX auf allen Hosts im TPU Pod-Slice installieren. Sie können JAX auf allen Hosts mit einem einzigen Befehl mithilfe der Option --worker=all installieren:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

JAX-Code auf dem Pod-Slice ausführen

Um JAX-Code auf einem TPU Pod-Slice auszuführen, müssen Sie den Code auf jedem Host im TPU Pod-Slice ausführen. Der Aufruf jax.device_count() reagiert erst dann, wenn er bei jedem Host im Pod-Slice aufgerufen wird. Das folgende Beispiel zeigt, wie eine einfache JAX-Berechnung auf einem TPU Pod-Slice ausgeführt wird.

Code vorbereiten

Sie benötigen die Version gcloud >= 344.0.0 (für den Befehl scp). Prüfen Sie mit gcloud --version Ihre gcloud-Version und führen Sie bei Bedarf gcloud components upgrade aus.

Schreiben Sie example.py auf den lokalen Computer:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py auf alle VMs im Pod-Slice kopieren

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all --zone=europe-west4-a

Wenn Sie den Befehl scp noch nicht verwendet haben, erhalten Sie möglicherweise eine Fehlermeldung wie die folgende:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Führen Sie den Befehl ssh-add aus, damit der Fehler behoben wird, und führen Sie den Befehl noch einmal aus.

Code auf dem Pod-Slice ausführen

Starten Sie auf jeder VM das Programm example.py:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a --worker=all --command "python3 example.py"

Ausgabe (mit einem v2-32-Pod-Slice erzeugt)

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

Bereinigen

Wenn Sie fertig sind, können Sie Ihre TPU-VM-Ressourcen mit dem Befehl gcloud freigeben:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a