JAX-Code auf TPU-Pod-Slices ausführen

Nachdem Sie den JAX-Code auf einem einzelnen TPU-Board ausgeführt haben, können Sie den Code skalieren, indem Sie ihn auf einem TPU-Pod-Slice ausführen. TPU Pod-Slices sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument bietet eine Einführung zum Ausführen von JAX-Code auf TPU Pod-Slices. Ausführlichere Informationen finden Sie unter JAX in Umgebungen mit mehreren Hosts und mehreren Prozessen verwenden.

Wenn Sie bereitgestellten NFS für die Datenspeicherung verwenden möchten, müssen Sie OS Login für alle TPU-VMs im Pod-Slice festlegen. Weitere Informationen finden Sie unter NFS für die Datenspeicherung verwenden.

TPU Pod-Slice erstellen

Bevor Sie die Befehle in diesem Dokument ausführen, müssen Sie die Anweisungen unter Konto und Cloud TPU-Projekt einrichten befolgt haben. Führen Sie die folgenden Befehle auf Ihrem lokalen Computer aus.

Erstellen Sie mit dem Befehl gcloud ein TPU Pod-Slice. Verwenden Sie beispielsweise den folgenden Befehl, um ein Pod-Slice für v4-32 zu erstellen:

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

JAX auf dem Pod-Slice installieren

Nachdem Sie das TPU Pod-Slice erstellt haben, müssen Sie JAX auf allen Hosts im TPU Pod-Slice installieren. Sie können JAX auf allen Hosts mit einem einzigen Befehl mithilfe der Option --worker=all installieren:

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

JAX-Code auf dem Pod-Slice ausführen

Um JAX-Code auf einem TPU Pod-Slice auszuführen, müssen Sie den Code auf jedem Host im TPU Pod-Slice ausführen. Der jax.device_count()-Aufruf reagiert nicht mehr, bis er auf jedem Host im Pod-Slice aufgerufen wird. Das folgende Beispiel veranschaulicht, wie eine einfache JAX-Berechnung auf einem TPU-Pod-Slice ausgeführt wird.

Code vorbereiten

Sie benötigen Version gcloud ≥ 344.0.0 (für den Befehl scp). Verwenden Sie gcloud --version, um Ihre gcloud-Version zu prüfen, und führen Sie bei Bedarf gcloud components upgrade aus.

Erstellen Sie eine Datei mit dem Namen example.py und dem folgenden Code:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

example.py auf alle TPU-Worker-VMs im Pod-Slice kopieren

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

Wenn Sie den Befehl scp noch nicht verwendet haben, wird möglicherweise ein Fehler wie der folgende angezeigt:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Um den Fehler zu beheben, führen Sie den Befehl ssh-add wie in der Fehlermeldung angegeben aus und wiederholen Sie den Befehl.

Code auf dem Pod-Slice ausführen

Starten Sie auf jeder VM das Programm example.py:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

Ausgabe (erzeugt mit einem v4-32-Pod-Slice):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Bereinigen

Wenn Sie fertig sind, können Sie Ihre TPU-VM-Ressourcen mit dem Befehl gcloud freigeben:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b