JAX-Code auf TPU Pod-Slices ausführen

Nachdem der JAX-Code auf einem einzelnen TPU-Board ausgeführt wurde, können Sie Ihren Code skalieren, indem Sie ihn auf einem TPU Pod-Slice ausführen. TPU Pod-Segmente sind mehrere TPU-Boards, die über dedizierte Hochgeschwindigkeits-Netzwerkverbindungen miteinander verbunden sind. Dieses Dokument bietet eine Einführung in die Ausführung von JAX-Code auf TPU Pod-Slices. Ausführlichere Informationen finden Sie unter JAX für Multi-Host- und Multi-Prozess-Umgebungen verwenden.

TPU Pod-Slice erstellen

Sie erstellen mit dem Befehl gcloud ein TPU Pod-Slice. Verwenden Sie beispielsweise den folgenden Befehl, um ein v2-32 Pod-Slice zu erstellen:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

JAX im Pod-Segment installieren

Nachdem das TPU Pod-Slice erstellt wurde, müssen Sie JAX auf allen Hosts im TPU Pod-Slice installieren. Mit der Option --worker=all können Sie JAX auf allen Hosts mit einem einzigen Befehl installieren:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install --upgrade jax jaxlib"

JAX-Code im Pod-Segment ausführen

Um JAX-Code auf einem TPU Pod-Slice auszuführen, müssen Sie den Code auf jedem Host im TPU Pod-Segment ausführen. Das heißt, Sie müssen eine SSH-Verbindung zu jedem Host herstellen und den JAX-Code auf jedem Host ausführen. Der folgende Python-Code zeigt, wie Sie mit der Option --worker=all des Befehls gcloud eine einfache JAX-Berechnung auf einem TPU Pod-Slice ausführen.

Code vorbereiten

$ read -r -d '' PYTHON_CMD << EOF
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
EOF

Code im Pod-Segment ausführen

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command "python3 -c \"$PYTHON_CMD\""

Ausgabe (mit einem v2-32-Pod-Segment erzeugt):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

Dies ist eine Möglichkeit, um JAX-Python-Code auf jedem Host auszuführen. Sie können jedoch beliebige Methoden verwenden. Sie führen es jedoch aus, bis der obige jax.device_count()-Aufruf unterbrochen wird, bis er auf jedem Host im Pod-Slice aufgerufen wird, da alle Hosts vorhanden sein müssen, um die TPU-Laufzeit zu initialisieren.

Bereinigen

Wenn Sie fertig sind, können Sie die TPU-VM-Ressourcen mit dem Befehl gcloud freigeben:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a