Mantieni tutto organizzato con le raccolte Salva e classifica i contenuti in base alle tue preferenze.
Guida rapida: esegui un calcolo su una VM Cloud TPU utilizzando JAX

Esegui un calcolo su una VM Cloud TPU utilizzando JAX

Questo documento fornisce una breve introduzione al lavoro con JAX e Cloud TPU.

Prima di seguire questa guida rapida, devi creare un account Google Cloud Platform, installare l'interfaccia a riga di comando di Google Cloud e configurare il comando gcloud. Per saperne di più, vedi Configurare un account e un progetto Cloud TPU.

Installa Google Cloud CLI

L'interfaccia a riga di comando di Google Cloud contiene strumenti e librerie per interagire con i prodotti e i servizi Google Cloud. Per ulteriori informazioni, consulta la sezione Installazione di Google Cloud CLI.

Configura il comando gcloud

Esegui questi comandi per configurare gcloud in modo da utilizzare il tuo progetto GCP e installare i componenti necessari per l'anteprima della VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Abilita l'API Cloud TPU

  1. Abilita l'API Cloud TPU utilizzando il seguente comando gcloud in Cloud Shell. Puoi anche abilitarlo da Google Cloud Console.

    $ gcloud services enable tpu.googleapis.com
    
  2. Esegui il comando seguente per creare un'identità del servizio.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crea una VM Cloud TPU con gcloud

Con le VM Cloud TPU, il modello e il codice vengono eseguiti direttamente sulla macchina host TPU. Puoi utilizzare SSH direttamente nell'host TPU. Puoi eseguire codice arbitrario, installare pacchetti, visualizzare log ed eseguire il debug del codice direttamente sull'host TPU.

  1. Crea la tua VM TPU eseguendo il comando seguente da una piattaforma Cloud Shell di GCP o dal terminale del tuo computer su cui è installato Google Cloud CLI. Sostituisci la variabile --version con tpu-vm-base per le versioni TPU v2 e v3 oppure con tpu-vm-v4-base per le TPU TPU v4.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version tpu-software-version
    

    Campi obbligatori

    zone
    La zona in cui intendi creare la Cloud TPU.
    accelerator-type
    Il tipo di Cloud TPU da creare.
    version
    Versione software di Cloud TPU. Utilizza "tpu-vm-base" per le versioni TPU v2 e v3. Utilizza "tpu-vm-v4-base" con TPU v4.

Connettiti alla tua VM Cloud TPU

SSH nella VM TPU utilizzando il seguente comando:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Campi obbligatori

tpu_name
Il nome della VM TPU a cui ti connetti.
zone
La zona in cui hai creato Cloud TPU.

Installa JAX sulla tua VM Cloud TPU

(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Controllo del sistema

Verifica che tutto sia installato correttamente controllando che JAX visualizzi i core Cloud TPU e possa eseguire le operazioni di base:

Avvia l'interprete di Python 3:

(vm)$ python3
>>> import jax

Visualizza il numero di core TPU disponibili:

>>> jax.device_count()

Viene visualizzato il numero di core TPU, che deve essere 8 se utilizzi un TPU TPU v2 o v3 oppure 4 se utilizzi un TPU v4.

Esegui un semplice calcolo:

>>> jax.numpy.add(1, 1)

Viene visualizzato il risultato dell'aggiunta numpy:

Output del comando:

DeviceArray(2, dtype=int32)

Esci dall'interprete di Python:

>>> exit()

Esecuzione del codice JAX su una VM TPU

Ora puoi eseguire il codice JAX che preferisci. Gli esempi di lino sono un ottimo punto di partenza per l'esecuzione di modelli di machine learning standard in JAX. Ad esempio, per addestrare una rete convoluzionale MNIST di base:

  1. Installare set di dati Tensorflow

    (vm)$ pip install --upgrade clu
    
  2. Installa FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Esegui lo script di addestramento FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    L'output dello script dovrebbe essere simile al seguente:

    I0726 00:57:51.274136 139632684678208 train.py:146] epoch:  1, train_loss: 0.2423, train_accuracy: 92.96, test_loss: 0.0629, test_accuracy: 97.98
    I0726 00:57:52.741929 139632684678208 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.15, test_loss: 0.0434, test_accuracy: 98.61
    I0726 00:57:54.149238 139632684678208 train.py:146] epoch:  3, train_loss: 0.0417, train_accuracy: 98.73, test_loss: 0.0307, test_accuracy: 98.98
    I0726 00:57:55.570881 139632684678208 train.py:146] epoch:  4, train_loss: 0.0309, train_accuracy: 99.03, test_loss: 0.0273, test_accuracy: 99.13
    I0726 00:57:56.937045 139632684678208 train.py:146] epoch:  5, train_loss: 0.0251, train_accuracy: 99.21, test_loss: 0.0270, test_accuracy: 99.16

Esegui la pulizia

Al termine della VM TPU, segui questi passaggi per eseguire la pulizia delle risorse.

  1. Disconnettiti dall'istanza di Compute Engine, se non lo hai già fatto:

    (vm)$ exit
    
  2. Elimina Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Verifica che le risorse siano state eliminate eseguendo il comando seguente. Assicurati che la tua TPU non sia più presente nell'elenco. L'eliminazione può richiedere qualche minuto.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a
    

Note sul rendimento

Ecco alcuni dettagli importanti che sono particolarmente pertinenti per l'utilizzo delle TPU in JAX.

Spaziatura interna

Una delle cause più comuni delle prestazioni lente sulle TPU è l'introduzione di spaziatura interna involontaria:

  • Gli array in Cloud TPU sono in piastrelle. Questo comporta il riempimento di una delle dimensioni a un multiplo di 8 e una dimensione diversa a un multiplo di 128.
  • L'unità di moltiplicazione delle matrici funziona meglio con coppie di grandi matrici che riducono al minimo la necessità di spaziatura interna.

tipo djet 16

Per impostazione predefinita, la moltiplicazione delle matrici in JAX sulle TPU utilizza bdecimal16 con accumulo di fluttua32. Ciò può essere controllato con l'argomento di precisione sulle chiamate di funzione jax.numpy pertinenti (matmul, punto, einsum e così via). In particolare:

  • precision=jax.lax.Precision.DEFAULT: utilizza una precisione b floating16 mista (più veloce)
  • precision=jax.lax.Precision.HIGH: utilizza più pass MXU per ottenere una precisione più elevata
  • precision=jax.lax.Precision.HIGHEST: utilizza ancora più pass MXU per ottenere la massima precisione di Floatt32

JAX aggiunge anche il tipo dfluttuante 16, che puoi utilizzare per trasmettere gli array in modo esplicito a bfloat16, ad esempio jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Esecuzione di JAX in un Colab

Quando esegui il codice JAX in un blocco note di Colab, Colab crea automaticamente un nodo TPU legacy. I nodi TPU hanno un'architettura diversa. Per ulteriori informazioni, consulta Architettura di sistema.

Passaggi successivi

Per ulteriori informazioni su Cloud TPU, vedi: