JAX-Code auf TPU-Pod-Slices ausführen
Nachdem Sie den JAX-Code auf einem einzelnen TPU-Board ausgeführt haben, können Sie den Code skalieren, indem Sie ihn auf einem TPU-Pod-Slice ausführen. TPU Pod-Slices sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument bietet eine Einführung zum Ausführen von JAX-Code auf TPU Pod-Slices. Ausführlichere Informationen finden Sie unter JAX in Umgebungen mit mehreren Hosts und mehreren Prozessen verwenden.
Wenn Sie bereitgestellte NFS für die Datenspeicherung verwenden möchten, müssen Sie OS Login für alle TPU-VMs im Pod-Slice festlegen. Weitere Informationen finden Sie unter NFS als Datenspeicher verwenden.Umgebung einrichten
Führen Sie in Cloud Shell den folgenden Befehl aus, um sicherzustellen, die aktuelle Version von
gcloud
ausführen:$ gcloud components update
Wenn Sie
gcloud
installieren möchten, verwenden Sie den folgenden Befehl:$ sudo apt install -y google-cloud-sdk
Erstellen Sie einige Umgebungsvariablen:
$ export TPU_NAME=tpu-name $ export ZONE=us-central2-b $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-32
TPU Pod-Slice erstellen
Bevor Sie die Befehle in diesem Dokument ausführen, stellen Sie sicher, dass Sie die Konto und Cloud TPU-Projekt einrichten Führen Sie die folgenden Befehle auf Ihrem lokalen Computer aus.
Erstellen Sie mit dem Befehl gcloud
ein TPU Pod-Slice. Verwenden Sie beispielsweise den folgenden Befehl, um ein v4-32-Pod-Slice zu erstellen:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
JAX auf dem Pod-Slice installieren
Nachdem Sie das TPU Pod-Slice erstellt haben, müssen Sie JAX auf allen Hosts im TPU Pod-Slice installieren. Sie können JAX auf allen Hosts mit einem einzigen Befehl mithilfe der Option --worker=all
installieren:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --worker=all \ --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
JAX-Code auf dem Pod-Slice ausführen
Um JAX-Code auf einem TPU Pod-Slice auszuführen, müssen Sie den Code auf jedem Host im TPU Pod-Slice ausführen. Der jax.device_count()
-Aufruf reagiert nicht mehr, bis er auf jedem Host im Pod-Slice aufgerufen wird. Das folgende Beispiel zeigt, wie Sie
eine JAX-Berechnung auf einem TPU-Pod-Slice ausführen.
Code vorbereiten
Sie benötigen die gcloud
-Version >= 344.0.0 (für die
scp
-Befehl).
Prüfen Sie mit gcloud --version
Ihre gcloud
-Version und führen Sie bei Bedarf gcloud components upgrade
aus.
Erstellen Sie eine Datei mit dem Namen example.py
und folgendem Code:
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
example.py
in alle TPU-Worker-VMs im Pod-Slice kopieren
$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \ --worker=all \ --zone=${ZONE}
Wenn Sie den Befehl scp
noch nicht verwendet haben, wird möglicherweise ein Fehler wie der folgende angezeigt:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Führen Sie den Befehl ssh-add
wie in der Fehlermeldung angezeigt aus und führen Sie den Befehl noch einmal aus, um den Fehler zu beheben.
Code auf dem Pod-Slice ausführen
Starten Sie auf jeder VM das Programm example.py
:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --worker=all \ --command="python3 example.py"
Ausgabe (erzeugt mit einem v4-32-Pod-Slice):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
Bereinigen
Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.
Trennen Sie die Verbindung zur Compute Engine-Instanz, sofern noch nicht geschehen:
(vm)$ exit
Die Eingabeaufforderung sollte nun
username@projectname
lauten und angeben, dass Sie sich in Cloud Shell befinden.Löschen Sie Ihre Cloud TPU- und Compute Engine-Ressourcen.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE}
Prüfen Sie, ob die Ressourcen gelöscht wurden. Führen Sie dazu
gcloud compute tpus execution-groups list
aus. Der Löschvorgang kann einige Minuten dauern. Die Ausgabe des folgenden Befehls sollte keine der in dieser Anleitung erstellten Ressourcen enthalten:$ gcloud compute tpus tpu-vm list --zone=${ZONE}