JAX-Code auf TPU-Pod-Slices ausführen
Nachdem Sie den JAX-Code auf einem einzelnen TPU-Board ausgeführt haben, können Sie den Code skalieren, indem Sie ihn auf einem TPU-Pod-Slice ausführen. TPU Pod-Slices sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument bietet eine Einführung zum Ausführen von JAX-Code auf TPU Pod-Slices. Ausführlichere Informationen finden Sie unter JAX in Umgebungen mit mehreren Hosts und mehreren Prozessen verwenden.
Wenn Sie für die Datenspeicherung bereitgestellten NFS verwenden möchten, müssen Sie für alle TPU-VMs im Pod-Slice OS Login festlegen. Weitere Informationen finden Sie unter NFS als Datenspeicher verwenden.
TPU Pod-Slice erstellen
Bevor Sie die Befehle in diesem Dokument ausführen, prüfen Sie, ob Sie der Anleitung unter Konto und Cloud TPU-Projekt einrichten gefolgt sind. Führen Sie die folgenden Befehle auf Ihrem lokalen Computer aus.
Erstellen Sie mit dem Befehl gcloud
ein TPU Pod-Slice. Verwenden Sie beispielsweise den folgenden Befehl, um ein v2-32-Pod-Slice zu erstellen:
$ gcloud compute tpus tpu-vm create tpu-name \
--zone europe-west4-a \
--accelerator-type v2-32 \
--version tpu-vm-base
JAX auf dem Pod-Slice installieren
Nachdem Sie das TPU Pod-Slice erstellt haben, müssen Sie JAX auf allen Hosts im TPU Pod-Slice installieren. Sie können JAX auf allen Hosts mit einem einzigen Befehl mithilfe der Option --worker=all
installieren:
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone europe-west4-a \
--worker=all \
--command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
JAX-Code auf dem Pod-Slice ausführen
Um JAX-Code auf einem TPU Pod-Slice auszuführen, müssen Sie den Code auf jedem Host im TPU Pod-Slice ausführen. Der jax.device_count()
-Aufruf reagiert erst wieder, wenn er bei jedem Host im Pod-Slice aufgerufen wird. Das folgende Beispiel zeigt, wie Sie eine einfache JAX-Berechnung auf einem TPU Pod-Slice ausführen können.
Code vorbereiten
Sie benötigen die Version gcloud
>= 344.0.0 (für den Befehl scp
).
Prüfen Sie mit gcloud --version
Ihre gcloud
-Version und führen Sie bei Bedarf gcloud components upgrade
aus.
Schreiben Sie example.py
auf den lokalen Computer:
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
example.py
in alle VMs im Pod-Slice kopieren
$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
--worker=all --zone=europe-west4-a
Wenn Sie den Befehl scp
noch nicht verwendet haben, wird möglicherweise ein ähnlicher Fehler wie dieser angezeigt:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try again.
Führen Sie den ssh-add
-Befehl wie in der Fehlermeldung gezeigt aus und führen Sie den Befehl noch einmal aus, um den Fehler zu beheben.
Code auf dem Pod-Slice ausführen
Starten Sie auf jeder VM das Programm example.py
:
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone europe-west4-a --worker=all --command "python3 example.py"
Ausgabe (mit einem v2-32-Pod-Slice erzeugt)
global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]
Bereinigen
Wenn Sie fertig sind, können Sie Ihre TPU-VM-Ressourcen mit dem Befehl gcloud
freigeben:
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone europe-west4-a