JAX-Code auf TPU Pod-Slices ausführen

Wenn Ihr JAX-Code auf einem einzelnen TPU-Board ausgeführt wird, können Sie ihn vergrößern, indem Sie ihn auf einem TPU Pod-Slice ausführen. TPU Pod-Slices sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument ist eine Einführung in die Ausführung von JAX-Code auf TPU Pod-Slices. Ausführlichere Informationen finden Sie unter JAX in Umgebungen mit mehreren Hosts und mehreren Prozessen verwenden.

Wenn Sie bereitgestelltes NFS für die Datenspeicherung verwenden möchten, müssen Sie OS Login für alle TPU-VMs im Pod-Slice festlegen. Weitere Informationen finden Sie unter NFS zur Datenspeicherung verwenden.

TPU Pod-Slice erstellen

Folgen Sie der Anleitung unter Konto und Cloud TPU-Projekt einrichten, bevor Sie die Befehle in diesem Dokument ausführen. Führen Sie auf Ihrem lokalen Computer die folgenden Befehle aus.

Erstellen Sie ein TPU Pod-Slice mit dem Befehl gcloud. Verwenden Sie beispielsweise den folgenden Befehl, um ein v2-32-Pod-Slice zu erstellen:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

JAX auf dem Pod-Slice installieren

Nachdem Sie das TPU Pod-Slice erstellt haben, müssen Sie JAX auf allen Hosts im TPU Pod-Slice installieren. Sie können JAX auf allen Hosts mit einem einzigen Befehl mithilfe der Option --worker=all installieren:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

JAX-Code auf dem Pod-Slice ausführen

Wenn Sie JAX-Code auf einem TPU Pod-Slice ausführen möchten, müssen Sie den Code auf jedem Host im TPU Pod-Slice ausführen. Dies bedeutet, dass Sie eine SSH-Verbindung zu jedem Host herstellen und den JAX-Code auf jedem Host ausführen müssen. Der folgende Python-Code zeigt, wie Sie mit der Option --worker=all des Befehls gcloud eine einfache JAX-Berechnung auf einem TPU Pod-Slice ausführen.

Code vorbereiten

Sie benötigen die gcloud-Version >= 344.0.0 (für den Befehl scp). Verwenden Sie gcloud --version, um Ihre gcloud-Version zu prüfen, und führen Sie bei Bedarf gcloud components upgrade aus.

Schreiben Sie example.py auf den lokalen Computer:

   cat > example.py << EOF

   # The following code snippet will be run on all TPU hosts
   import jax

   # The total number of TPU cores in the Pod
   device_count = jax.device_count()

   # The number of TPU cores attached to this host
   local_device_count = jax.local_device_count()

   # The psum is performed over all mapped devices across the Pod
   xs = jax.numpy.ones(jax.local_device_count())
   r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

   # Print from a single host to avoid duplicated output
   if jax.process_index() == 0:
       print('global device count:', jax.device_count())
       print('local device count:', jax.local_device_count())
       print('pmap result:', r)
   EOF

Kopieren Sie example.py in alle VMs im Pod-Slice.

$ gcloud alpha compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all --zone=europe-west4-a

Wenn Sie den Befehl scp zum ersten Mal verwenden, wird möglicherweise ein Fehler wie der folgende angezeigt:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent.
Please run "ssh-add /.../.ssh/google_compute_engine" to add it, and try again.
Führen Sie den Befehl ssh-add aus, wie in der Fehlermeldung angezeigt, und führen Sie den Befehl noch einmal aus, um den Fehler zu beheben.

Code im Pod-Slice ausführen

Starten Sie das Programm example.py auf jeder VM:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a --worker=all --command "python3 example.py"

Ausgabe (generiert mit einem v2-32-Pod-Slice):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

So können Sie den JAX-Python-Code auf jedem Host ausführen, Sie können aber auch eine beliebige andere Methode verwenden. Unabhängig von der Ausführung hängt der obige jax.device_count()-Aufruf jedoch ab, bis er auf jedem Host im Pod-Slice aufgerufen wird, da alle Hosts vorhanden sein müssen, um die TPU-Laufzeit zu initialisieren.

Bereinigen

Wenn Sie fertig sind, können Sie Ihre TPU-VM-Ressourcen mit dem Befehl gcloud freigeben:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a