Esegui il codice PyTorch sulle sezioni di pod di TPU

PyTorch/XLA richiede che tutte le VM TPU siano in grado di accedere al codice e ai dati del modello. Puoi utilizzare uno script di avvio per scaricare il software necessario per distribuire i dati del modello a tutte le VM TPU.

Se connetti le VM TPU a un Virtual Private Cloud (VPC), devi aggiungere una regola firewall nel progetto per consentire il traffico in entrata per le porte 8470 - 8479. Per maggiori informazioni sull'aggiunta di regole firewall, consulta Utilizzo delle regole firewall

configura l'ambiente

  1. In Cloud Shell, esegui questo comando per assicurarti di eseguire la versione attuale di gcloud:

    $ gcloud components update
    

    Se devi installare gcloud, utilizza il seguente comando:

    $ sudo apt install -y google-cloud-sdk
  2. Crea alcune variabili di ambiente:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Crea la VM TPU

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Configurare ed eseguire lo script di addestramento

  1. Aggiungi il certificato SSH al progetto:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Installa PyTorch/XLA su tutti i worker TPU per VM

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. Clona XLA su tutti i worker VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.2 https://github.com/pytorch/xla.git"
    
  4. Esegui lo script di addestramento su tutti i worker

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    L'addestramento richiede circa 5 minuti. Al termine, dovresti visualizzare un messaggio simile al seguente:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Esegui la pulizia

Al termine delle operazioni della VM TPU, segui questi passaggi per la pulizia delle risorse.

  1. Disconnettiti da Compute Engine:

    (vm)$ exit
    
  2. Verifica che le risorse siano state eliminate eseguendo questo comando. Assicurati che la tua TPU non sia più elencata. L'eliminazione può richiedere qualche minuto.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a