Esecuzione di un calcolo su una VM Cloud TPU utilizzando JAX
Questo documento fornisce una breve introduzione all'utilizzo di JAX e Cloud TPU.
Prima di seguire questa guida rapida, devi creare un account Google Cloud, installare Google Cloud CLI e configurare il comando gcloud
.
Per saperne di più, consulta Configurare un account e un progetto Cloud TPU.
Installa Google Cloud CLI
Google Cloud CLI contiene strumenti e librerie per interagire con i prodotti e i servizi Google Cloud. Per ulteriori informazioni, vedi Installazione di Google Cloud CLI.
Configura il comando gcloud
Esegui i seguenti comandi per configurare gcloud
in modo che utilizzi il tuo progetto Google Cloud e installa i componenti necessari per l'anteprima della VM TPU.
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
Abilita l'API Cloud TPU
Abilita l'API Cloud TPU utilizzando il seguente comando
gcloud
in Cloud Shell. Puoi anche attivarla dalla console Google Cloud.$ gcloud services enable tpu.googleapis.com
Esegui questo comando per creare un'identità di servizio.
$ gcloud beta services identity create --service tpu.googleapis.com
Crea una VM Cloud TPU con gcloud
Con le VM Cloud TPU, modello e codice vengono eseguiti direttamente sull'host TPU in una macchina virtuale. Accedi direttamente tramite SSH all'host della TPU. Puoi eseguire codice arbitrario, installare pacchetti, visualizzare i log e eseguire il debug del codice direttamente sull'host TPU.
Crea la tua VM TPU eseguendo questo comando da Cloud Shell oppure il terminale del computer su cui Google Cloud CLI è già installato.
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central1-a \ --accelerator-type=v3-8 \ --version=tpu-ubuntu2204-base
Campi obbligatori
zone
- La zona in cui prevedi di creare la tua Cloud TPU.
accelerator-type
- Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratori supportati per ogni versione di TPU, vedi Versioni TPU.
version
- La versione software di Cloud TPU. Per tutti i tipi di TPU, utilizza
tpu-ubuntu2204-base
.
Connettiti alla VM Cloud TPU
Accedi tramite SSH alla VM TPU utilizzando il seguente comando:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a
Campi obbligatori
tpu_name
- Il nome della VM TPU a cui ti stai connettendo.
zone
- La zona in cui hai creato la Cloud TPU.
Installa JAX sulla tua VM Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Controllo del sistema
Verifica che JAX possa accedere alla TPU e possa eseguire le operazioni di base:
Avvia l'interprete Python 3:
(vm)$ python3
>>> import jax
Visualizza il numero di core TPU disponibili:
>>> jax.device_count()
Viene visualizzato il numero di core TPU. Se utilizzi una TPU v4, dovrebbe essere
4
. Se utilizzi una TPU v2 o v3, il valore deve essere 8
.
Esegui un semplice calcolo:
>>> jax.numpy.add(1, 1)
Viene visualizzato il risultato dell'addizione di NumPy:
Output dal comando:
Array(2, dtype=int32, weak_type=true)
Esci dall'interprete Python:
>>> exit()
Esecuzione del codice JAX su una VM TPU
Ora puoi eseguire qualsiasi codice JAX. Gli esempi di lino sono un ottimo punto di partenza per eseguire modelli ML standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:
Installa le dipendenze degli esempi di Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Installa FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Esegui lo script di addestramento MNIST FLAX
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Lo script scarica il set di dati e avvia l'addestramento. L'output dello script dovrebbe ha questo aspetto:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Esegui la pulizia
Al termine dell'utilizzo della VM TPU, segui questi passaggi per ripulire le risorse.
Se non l'hai ancora fatto, disconnetti dall'istanza Compute Engine:
(vm)$ exit
Elimina il tuo Cloud TPU.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central1-a
Verifica che le risorse siano state eliminate eseguendo questo comando. Marca assicurati che la tua TPU non sia più elencata. L'eliminazione può richiedere qualche minuto.
$ gcloud compute tpus tpu-vm list \ --zone=us-central1-a
Note sulle prestazioni
Di seguito sono riportati alcuni dettagli importanti particolarmente pertinenti per l'utilizzo delle TPU in JAX.
Spaziatura interna
Una delle cause più comuni del rallentamento delle prestazioni sulle TPU è l'introduzione spaziatura interna involontaria:
- Gli array in Cloud TPU sono suddivisi in riquadri. Ciò comporta la spaziatura interna in uno dei dimensioni a un multiplo di 8 e una dimensione diversa a un multiplo di 128.
- L'unità di moltiplicazione matriciale ha il rendimento migliore con coppie di matrici grandi per ridurre al minimo la necessità di spaziatura interna.
Tipo di dati bfloat16
Per impostazione predefinita, la moltiplicazione di matrici in JAX su TPU utilizza bfloat16 con accumulo float32. Questo può essere controllato con l'argomento di precisione chiamate di funzione jax.numpy pertinenti (matmul, punto, einsum e così via). In particolare:
precision=jax.lax.Precision.DEFAULT
: utilizza bfloat16 misto precisione (più veloce)precision=jax.lax.Precision.HIGH
: utilizza più passaggi MXU per ottenere una maggiore precisioneprecision=jax.lax.Precision.HIGHEST
: utilizza ancora più passaggi MXU per ottenere una precisione completa di float32
JAX aggiunge anche il dtype bfloat16, che puoi usare per trasmettere gli array in modo esplicito
bfloat16
, ad esempio,
jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Esecuzione di JAX in un Colab
Quando esegui codice JAX in un blocco note di Colab, Colab crea automaticamente Nodo TPU. I nodi TPU hanno un'architettura diversa. Per ulteriori informazioni, vedi Architettura di sistema.
Passaggi successivi
Per ulteriori informazioni su Cloud TPU, vedi:
- Eseguire il codice JAX nelle sezioni di pod TPU
- Gestisci le TPU
- Architettura di sistema di Cloud TPU