Esegui un calcolo su una VM Cloud TPU utilizzando JAX

Questo documento fornisce una breve introduzione al lavoro con JAX e Cloud TPU.

Prima di seguire questa guida rapida, devi creare un account per Google Cloud Platform, installare Google Cloud CLI e configurare il comando gcloud. Per saperne di più, vedi Configurare un account e un progetto Cloud TPU.

Installa Google Cloud CLI

Google Cloud CLI contiene strumenti e librerie per interagire con i prodotti e i servizi Google Cloud. Per maggiori informazioni, consulta Installare Google Cloud CLI.

Configura il comando gcloud

Esegui i comandi seguenti per configurare gcloud in modo da utilizzare il progetto Google Cloud e installare i componenti necessari per l'anteprima della VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Abilita l'API Cloud TPU

  1. Abilita l'API Cloud TPU utilizzando il seguente comando gcloud in Cloud Shell. (puoi anche abilitarlo dalla console Google Cloud).

    $ gcloud services enable tpu.googleapis.com
    
  2. Esegui questo comando per creare un'identità di servizio.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crea una VM Cloud TPU con gcloud

Con le VM Cloud TPU, il modello e il codice vengono eseguiti direttamente sulla macchina host della TPU. Accedi direttamente all'host TPU tramite SSH. Puoi eseguire codice arbitrario, installare pacchetti, visualizzare i log e il codice di debug direttamente sull'host TPU.

  1. Crea la VM TPU eseguendo il comando seguente da Cloud Shell o dal terminale del computer in cui è installato Google Cloud CLI.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base
    

    Campi obbligatori

    zone
    La zona in cui prevedi di creare la Cloud TPU.
    accelerator-type
    Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per maggiori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, consulta Versioni TPU.
    version
    La versione software di Cloud TPU. Per tutti i tipi di TPU, utilizza tpu-ubuntu2204-base.

Connettiti alla VM Cloud TPU

Accedi alla VM TPU utilizzando il seguente comando:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

Campi obbligatori

tpu_name
Il nome della VM TPU a cui ti stai connettendo.
zone
La zona in cui hai creato la Cloud TPU.

Installa JAX sulla VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Controllo del sistema

Verifica che JAX possa accedere alla TPU e possa eseguire operazioni di base:

Avvia l'interprete Python 3:

(vm)$ python3
>>> import jax

Visualizza il numero di core TPU disponibili:

>>> jax.device_count()

Il numero di core TPU viene visualizzato. Se utilizzi una TPU v4, dovrebbe essere 4. Se utilizzi una TPU v2 o v3, dovrebbe essere 8.

Esegui un semplice calcolo:

>>> jax.numpy.add(1, 1)

Viene visualizzato il risultato dell'aggiunta di numpy:

Output dal comando:

Array(2, dtype=int32, weak_type=true)

Esci dall'interprete Python:

>>> exit()

Esecuzione di codice JAX su una VM TPU

Ora puoi eseguire qualsiasi codice JAX che vuoi. Gli esempi di lino sono un ottimo punto di partenza per eseguire modelli ML standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:

  1. Installa le dipendenze degli esempi Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. Installa FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. Esegui lo script di addestramento FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

Lo script scarica il set di dati e avvia l'addestramento. L'output dello script dovrebbe essere simile a questo:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Esegui la pulizia

Al termine delle operazioni della VM TPU, segui questi passaggi per la pulizia delle risorse.

  1. Disconnettiti dall'istanza di Compute Engine, se non lo hai già fatto:

    (vm)$ exit
    
  2. Elimina la tua Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
    
  3. Verifica che le risorse siano state eliminate eseguendo questo comando. Assicurati che la tua TPU non sia più elencata. L'eliminazione può richiedere qualche minuto.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

Note sulle prestazioni

Di seguito sono riportati alcuni dettagli importanti particolarmente importanti per l'utilizzo delle TPU in JAX.

Spaziatura interna

Una delle cause più comuni di lentezza delle prestazioni sulle TPU è l'introduzione di spaziatura interna involontaria:

  • Gli array in Cloud TPU sono affiancati. Ciò comporta il riempimento di una delle dimensioni fino a un multiplo di 8 e una dimensione diversa a un multiplo di 128.
  • L'unità di moltiplicazione matriciale funziona meglio con coppie di matrici di grandi dimensioni che riducono al minimo la necessità di spaziatura.

bfloat16 dtype

Per impostazione predefinita, la moltiplicazione della matrice in JAX sulle TPU utilizza bfloat16 con l'accumulo con float32. Può essere controllato con l'argomento precisione sulle chiamate di funzione jax.numpy pertinenti (matmul, punto, einsum e così via). In particolare:

  • precision=jax.lax.Precision.DEFAULT: utilizza la precisione bfloat16 mista (più veloce)
  • precision=jax.lax.Precision.HIGH: utilizza più pass MXU per ottenere una maggiore precisione
  • precision=jax.lax.Precision.HIGHEST: utilizza ancora più passaggi MXU per raggiungere la precisione float32

JAX aggiunge anche il comando dtype bfloat16, che puoi usare per trasmettere esplicitamente array a bfloat16, ad esempio jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Esecuzione di JAX in un Colab

Quando esegui codice JAX in un blocco note di Colab, Colab crea automaticamente un nodo TPU legacy. I nodi TPU hanno un'architettura diversa. Per ulteriori informazioni, consulta Architettura di sistema.

Passaggi successivi

Per ulteriori informazioni su Cloud TPU, vedi: