Esegui un calcolo su una VM Cloud TPU utilizzando JAX

Questo documento fornisce una breve introduzione all'utilizzo di JAX e Cloud TPU.

Prima di seguire questa guida rapida, devi creare un account Google Cloud Platform, installare Google Cloud CLI e configurare il comando gcloud. Per saperne di più, consulta Configurare un account e un progetto Cloud TPU.

Installa Google Cloud CLI

Google Cloud CLI contiene strumenti e librerie per interagire con i prodotti e i servizi Google Cloud. Per ulteriori informazioni, consulta Installare Google Cloud CLI.

Configura il comando gcloud

Esegui i seguenti comandi per configurare gcloud in modo che utilizzi il tuo progetto Google Cloud e installa i componenti necessari per l'anteprima della VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Abilita l'API Cloud TPU

  1. Abilita l'API Cloud TPU utilizzando il seguente comando gcloud in Cloud Shell. Puoi anche attivarla dalla console Google Cloud.

    $ gcloud services enable tpu.googleapis.com
  2. Esegui il comando seguente per creare un'identità di servizio.

    $ gcloud beta services identity create --service tpu.googleapis.com

Crea una VM Cloud TPU con gcloud

Con le VM Cloud TPU, il modello e il codice vengono eseguiti direttamente sulla macchina host TPU. Esegui SSH direttamente nell'host TPU. Puoi eseguire codice arbitrario, installare pacchetti, visualizzare i log e eseguire il debug del codice direttamente sull'host TPU.

  1. Crea la VM TPU eseguendo il seguente comando da Cloud Shell o dal terminale del computer su cui è installato Google Cloud CLI.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central1-a \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base

    Campi obbligatori

    zone
    La zona in cui prevedi di creare la Cloud TPU.
    accelerator-type
    Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU.
    version
    La versione software di Cloud TPU. Per tutti i tipi di TPU, utilizza tpu-ubuntu2204-base.

Connettiti alla VM Cloud TPU

Accedi tramite SSH alla VM TPU utilizzando il seguente comando:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a

Campi obbligatori

tpu_name
Il nome della VM TPU a cui ti stai connettendo.
zone
La zona in cui hai creato la Cloud TPU.

Installa JAX sulla tua VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Controllo del sistema

Verifica che JAX possa accedere alla TPU ed eseguire operazioni di base:

Avvia l'interprete Python 3:

(vm)$ python3
>>> import jax

Mostra il numero di core TPU disponibili:

>>> jax.device_count()

Viene visualizzato il numero di core TPU. Se utilizzi una TPU v4, dovrebbe essere 4. Se utilizzi una TPU v2 o v3, il valore deve essere 8.

Esegui un calcolo semplice:

>>> jax.numpy.add(1, 1)

Viene visualizzato il risultato dell'addizione di NumPy:

Output del comando:

Array(2, dtype=int32, weak_type=true)

Esci dall'interprete Python:

>>> exit()

Esecuzione del codice JAX su una VM TPU

Ora puoi eseguire qualsiasi codice JAX. Gli esempi di flax sono un ottimo punto di partenza per eseguire modelli di ML standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:

  1. Installa le dipendenze degli esempi di Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Installa FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Esegui lo script di addestramento FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

Lo script scarica il set di dati e avvia l'addestramento. L'output dello script dovrebbe essere simile al seguente:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Esegui la pulizia

Al termine dell'utilizzo della VM TPU, segui questi passaggi per ripulire le risorse.

  1. Se non l'hai ancora fatto, disconnetti dall'istanza Compute Engine:

    (vm)$ exit
  2. Elimina la Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central1-a
  3. Verifica che le risorse siano state eliminate eseguendo il seguente comando. Assicurati che la TPU non sia più inclusa nell'elenco. L'eliminazione può richiedere qualche minuto.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central1-a

Note sul rendimento

Di seguito sono riportati alcuni dettagli importanti particolarmente pertinenti per l'utilizzo delle TPU in JAX.

Spaziatura interna

Una delle cause più comuni di prestazioni lente sulle TPU è l'introduzione di spaziatura involontaria:

  • Gli array in Cloud TPU sono suddivisi in riquadri. Ciò comporta l'aggiunta di spazi iniziali a una delle dimensioni in modo da ottenere un multiplo di 8 e a un'altra dimensione in modo da ottenere un multiplo di 128.
  • L'unità di moltiplicazione delle matrici ha il rendimento migliore con coppie di matrici di grandi dimensioni che riducono al minimo la necessità di spaziatura interna.

Tipo di dati bfloat16

Per impostazione predefinita, la moltiplicazione di matrici in JAX su TPU utilizza bfloat16 con accumulo float32. Questo può essere controllato con l'argomento precision sulle chiamate alle funzioni jax.numpy pertinenti (matmul, dot, einsum e così via). In particolare:

  • precision=jax.lax.Precision.DEFAULT: utilizza la precisione bfloat16 mista (più veloce)
  • precision=jax.lax.Precision.HIGH: utilizza più passaggi MXU per ottenere una maggiore precisione
  • precision=jax.lax.Precision.HIGHEST: utilizza ancora più passaggi MXU per ottenere una precisione completa di float32

JAX aggiunge anche il tipo di dati bfloat16, che puoi utilizzare per eseguire il casting esplicito degli array in bfloat16, ad esempio,jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Eseguire JAX in un Colab

Quando esegui il codice JAX in un blocco note di Colab, Colab crea automaticamente un node TPU legacy. I nodi TPU hanno un'architettura diversa. Per maggiori informazioni, consulta la pagina Architettura di sistema.

Passaggi successivi

Per ulteriori informazioni su Cloud TPU, consulta: