Esegui un calcolo su una VM Cloud TPU utilizzando JAX

Esegui un calcolo su una VM Cloud TPU utilizzando JAX

Questo documento fornisce una breve introduzione all'uso di JAX e Cloud TPU.

Prima di seguire questa guida rapida, devi creare un account Google Cloud Platform, installare lGoogle Cloud CLI e configurare il comando gcloud. Per saperne di più, vedi Configurare un account e un progetto Cloud TPU.

Installa Google Cloud CLI

Google Cloud CLI contiene strumenti e librerie per interagire con i prodotti e i servizi Google Cloud. Per saperne di più, consulta Installazione di Google Cloud CLI.

Configura il comando gcloud

Esegui i comandi seguenti per configurare gcloud in modo da utilizzare il tuo progetto Google Cloud e installare i componenti necessari per l'anteprima della VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Abilita l'API Cloud TPU

  1. Abilita l'API Cloud TPU utilizzando il comando gcloud seguente in Cloud Shell. Puoi anche abilitarlo dalla console Google Cloud.

    $ gcloud services enable tpu.googleapis.com
    
  2. Esegui questo comando per creare un'identità di servizio.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crea una VM Cloud TPU con gcloud

Con le VM Cloud TPU, il modello e il codice vengono eseguiti direttamente sulla macchina host TPU. Accedi tramite SSH direttamente all'host TPU. Puoi eseguire codice arbitrario, installare pacchetti, visualizzare i log ed eseguire il debug del codice direttamente sull'host TPU.

  1. Crea la tua VM TPU eseguendo il comando seguente da un dispositivo Google Cloud Shell o dal terminale del computer su cui è installato Google Cloud CLI.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-vm-v4-base
    

    Campi obbligatori

    zone
    La zona in cui prevedi di creare la tua Cloud TPU.
    accelerator-type
    Il tipo di Cloud TPU da creare.
    version
    La versione del software Cloud TPU. Per le TPU v2 e v3, utilizza tpu-vm-base. Per le TPU v4, utilizza tpu-vm-v4-base.

Connettiti alla tua VM Cloud TPU

Accedi tramite SSH alla VM TPU utilizzando il seguente comando:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

Campi obbligatori

tpu_name
Il nome della VM TPU a cui ti stai connettendo.
zone
La zona in cui hai creato il tuo Cloud TPU.

Installa JAX sulla tua VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Controllo del sistema

Verifica che JAX possa accedere alla TPU e possa eseguire operazioni di base:

Avvia l'interprete di Python 3:

(vm)$ python3
>>> import jax

Mostra il numero di core TPU disponibili:

>>> jax.device_count()

Viene visualizzato il numero di core TPU. Se utilizzi un TPU v4, questo dovrebbe essere 4. Se utilizzi un TPU v2 o v3, questo deve essere 8.

Esegui un calcolo semplice:

>>> jax.numpy.add(1, 1)

Viene visualizzato il risultato dell'aggiunta di numeri:

Output del comando:

Array(2, dtype=int32, weak_type=true)

Esci dall'interprete Python:

>>> exit()

Esecuzione di codice JAX su una VM TPU

Ora puoi eseguire qualsiasi codice JAX. Gli esempi di lino sono un ottimo punto di partenza per l'esecuzione dei modelli di ML standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:

  1. Installa le dipendenze di esempio di Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. Installa FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. Esegui lo script di addestramento FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

Lo script scarica il set di dati e inizia l'addestramento. L'output dello script dovrebbe essere simile al seguente:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Esegui la pulizia

Quando hai finito di utilizzare la VM TPU, segui questi passaggi per ripulire le risorse.

  1. Disconnettiti dall'istanza di Compute Engine, se non l'hai ancora fatto:

    (vm)$ exit
    
  2. Elimina Cloud TPU.

    $ gcloud compute tpus tpu-vm delete
    tpu-name \
      --zone=us-central2-b
    
  3. Verifica che le risorse siano state eliminate eseguendo il comando seguente. Assicurati che il tuo TPU non sia più presente nell'elenco. L'eliminazione può richiedere qualche minuto.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

Note sul rendimento

Di seguito sono riportati alcuni dettagli importanti particolarmente pertinenti per l'utilizzo delle TPU in JAX.

Spaziatura interna

Una delle cause più comuni per le prestazioni ridotte sulle TPU è l'introduzione di spaziatura interna involontaria:

  • Gli array in Cloud TPU sono suddivisi in riquadri. Questo comporta il riempimento di una delle dimensioni a un multiplo di 8 e di una dimensione diversa a un multiplo di 128.
  • L'unità di moltiplicazione delle matrici funziona meglio con coppie di grandi matrici che riducono al minimo la necessità di riempimento.

dtype b floating16

Per impostazione predefinita, la moltiplicazione delle matrici in JAX su TPU utilizza b floating16 con accumulo fluttuante 32. Questo può essere controllato con l'argomento della precisione sulle chiamate delle funzioni jax.numpy pertinenti (matmul, punto, einsum ecc.). In particolare:

  • precision=jax.lax.Precision.DEFAULT: utilizza la precisione della combinazione di dati bFlo16 (più veloce)
  • precision=jax.lax.Precision.HIGH: utilizza più pass MXU per ottenere una maggiore precisione
  • precision=jax.lax.Precision.HIGHEST: utilizza un numero ancora maggiore di pass MXU per raggiungere il massimo della precisione in virgola mobile 32

JAX aggiunge anche il dtype b floating16, che puoi utilizzare per trasmettere gli array in modo esplicito a bfloat16, ad esempio jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Esecuzione di JAX in Colab

Quando esegui il codice JAX in un blocco note Colab, Colab crea automaticamente un nodo TPU legacy. I nodi TPU hanno un'architettura diversa. Per ulteriori informazioni, consulta Architettura di sistema.

Passaggi successivi

Per ulteriori informazioni su Cloud TPU, vedi: