Esegui codice JAX sulle sezioni di pod di TPU

Dopo aver eseguito il codice JAX su una singola scheda TPU, puoi fare lo scale up del codice eseguendolo su una sezione di pod TPU. Le sezioni di pod di TPU sono più schede TPU collegate tra loro tramite connessioni di rete dedicate ad alta velocità. Questo documento è un'introduzione all'esecuzione del codice JAX sulle sezioni dei pod di TPU; per informazioni più approfondite, consulta la pagina relativa all'utilizzo di JAX in ambienti multi-host e multi-processo.

Se vuoi utilizzare NFS montato per l'archiviazione dei dati, devi impostare OS Login per tutte le VM TPU nella sezione di pod. Per maggiori informazioni, consulta Utilizzo di un NFS per l'archiviazione dei dati.

Crea una sezione di pod di TPU

Prima di eseguire i comandi in questo documento, assicurati di aver seguito le istruzioni riportate in Configurare un account e un progetto Cloud TPU. Esegui i seguenti comandi sulla tua macchina locale.

Crea una sezione di pod di TPU utilizzando il comando gcloud. Ad esempio, per creare una sezione di pod v4-32, utilizza il seguente comando:

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

Installa JAX nella sezione del pod

Dopo aver creato la sezione del pod di TPU, devi installare JAX su tutti gli host nella sezione del pod TPU. Puoi installare JAX su tutti gli host con un singolo comando utilizzando l'opzione --worker=all:

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

esegui il codice JAX nella sezione del pod

Per eseguire il codice JAX su una sezione di pod di TPU, devi eseguire il codice su ciascun host nella sezione di pod di TPU TPU. La chiamata jax.device_count() smette di rispondere finché non viene richiamata su ciascun host nella sezione di pod. L'esempio seguente illustra come eseguire un semplice calcolo JAX su una sezione di pod di TPU.

Prepara il codice

È necessaria la versione di gcloud >= 344.0.0 (per il comando scp). Usa gcloud --version per controllare la versione di gcloud ed esegui gcloud components upgrade, se necessario.

Crea un file denominato example.py con il seguente codice:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py in tutte le VM worker TPU nella sezione di pod

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

Se non hai mai usato il comando scp, potresti visualizzare un errore simile al seguente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Per risolvere l'errore, esegui il comando ssh-add come visualizzato nel messaggio di errore ed esegui nuovamente il comando.

esegui il codice nella sezione di pod

Avvia il programma example.py su ogni VM:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

Output (prodotto con una sezione di pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Esegui la pulizia

Al termine, puoi rilasciare le risorse VM TPU utilizzando il comando gcloud:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b