Esegui il codice JAX nelle sezioni di pod TPU

Dopo aver eseguito il codice JAX su una singola scheda TPU, puoi eseguirne l'upgrade su una sezione di pod TPU. Le sezioni di pod di TPU sono più schede TPU collegate tra loro tramite connessioni di rete ad alta velocità. Questo documento è un'introduzione alla gestione di JAX sulle sezioni di pod di TPU; per informazioni più approfondite, consulta la sezione Utilizzo di JAX in ambienti multi-host e con più processi.

Se vuoi utilizzare NFS montato per l'archiviazione dei dati, devi impostare OS Login per tutti VM TPU nella sezione del pod. Per ulteriori informazioni, consulta Utilizzo di un NFS per l'archiviazione dei dati.

Configura l'ambiente

  1. In Cloud Shell, esegui questo comando per assicurarti di essere con la versione corrente di gcloud:

    $ gcloud components update

    Se devi installare gcloud, utilizza il seguente comando:

    $ sudo apt install -y google-cloud-sdk
  2. Crea alcune variabili di ambiente:

    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32

Creare un segmento di pod TPU

Prima di eseguire i comandi in questo documento, assicurati di aver seguito le per le istruzioni in Configurare un account e un progetto Cloud TPU. Esegui i seguenti comandi sulla tua macchina locale.

Crea uno slice del pod TPU utilizzando il comando gcloud. Ad esempio, per creare un'istanza La sezione di pod v4-32 utilizza il seguente comando:

$ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
  --zone=${ZONE} \
  --accelerator-type=${ACCELERATOR_TYPE}  \
  --version=${RUNTIME_VERSION} 

Installa JAX nella sezione di pod

Dopo aver creato la sezione di pod di TPU, devi installare JAX su tutti gli host nella TPU Sezione di pod. Puoi installare JAX su tutti gli host con un solo comando utilizzando Opzione --worker=all:

  gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Esegui il codice JAX nella sezione del pod

Per eseguire il codice JAX su una sezione di pod di TPU, devi eseguire il codice su ciascun host nel Pod di TPU TPU. La chiamata jax.device_count() smette di rispondere finché non viene chiamata su ogni host nel segmento del pod. L'esempio seguente illustra come eseguire un calcolo JAX su una sezione di pod TPU.

Preparare il codice

È necessaria la versione gcloud >= 344.0.0 (per il scp). Utilizza gcloud --version per controllare la versione di gcloud ed eseguire gcloud components upgrade, se necessario.

Crea un file denominato example.py con il seguente codice:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py in tutte le VM worker TPU nella sezione del pod

$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \
  --worker=all \
  --zone=${ZONE}

Se non hai mai utilizzato il comando scp, potresti visualizzare un errore simile al seguente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Per risolvere l'errore, esegui il comando ssh-add visualizzato nel messaggio di errore e riavvialo.

Esegui il codice nella sezione del pod

Avvia il programma example.py su ogni VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --worker=all \
  --command="python3 example.py"

Output (generato con una sezione di pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Esegui la pulizia

Al termine dell'utilizzo della VM TPU, segui questi passaggi per ripulire le risorse.

  1. Disconnettiti dall'istanza Compute Engine, se non lo hai già fatto Fatto:

    (vm)$ exit

    Il tuo prompt dovrebbe ora essere username@projectname, a indicare che ti trovi in Cloud Shell.

  2. Elimina le tue risorse Cloud TPU e Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE}
  3. Verifica che le risorse siano state eliminate eseguendo gcloud compute tpus execution-groups list. La l'eliminazione potrebbe richiedere diversi minuti. L'output del comando seguente non devono includere nessuna delle risorse create in questo tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}