Esegui codice PyTorch sulle sezioni di pod di TPU

PyTorch/XLA richiede che tutte le VM TPU possano accedere al codice e ai dati del modello. Puoi utilizzare uno script di avvio per scaricare il software necessario per distribuire i dati del modello in tutte le VM TPU.

Se colleghi le VM TPU a un Virtual Private Cloud (VPC) devi aggiungere una regola firewall nel progetto per consentire il traffico in entrata per le porte 8470 - 8479. Per ulteriori informazioni sull'aggiunta di regole firewall, consulta Utilizzo delle regole firewall

Configura l'ambiente

  1. In Cloud Shell, esegui questo comando per assicurarti di essere con la versione corrente di gcloud:

    $ gcloud components update
    

    Se devi installare gcloud, usa il seguente comando:

    $ sudo apt install -y google-cloud-sdk
  2. Crea alcune variabili di ambiente:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Crea la VM TPU

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Configura ed esegui lo script di addestramento

  1. Aggiungi il certificato SSH al progetto:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Installa PyTorch/XLA su tutti i worker VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. Clona XLA su tutti i worker VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.3 https://github.com/pytorch/xla.git"
    
  4. Esegui lo script di addestramento su tutti i worker

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    L'addestramento richiede circa 5 minuti. Al termine, dovresti vedere una simile al seguente:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Esegui la pulizia

Al termine delle operazioni con la VM TPU, segui questi passaggi per pulire le risorse.

  1. Disconnettiti da Compute Engine:

    (vm)$ exit
    
  2. Verifica che le risorse siano state eliminate eseguendo questo comando. Marca assicurati che la tua TPU non sia più elencata. L'eliminazione può richiedere qualche minuto.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a