Addestramento di Resnet50 su Cloud TPU con PyTorch


Questo tutorial mostra come addestrare il modello ResNet-50 su un dispositivo Cloud TPU con PyTorch. Puoi applicare lo stesso pattern ad altri modelli di classificazione delle immagini ottimizzati per TPU che utilizzano PyTorch e il set di dati ImageNet.

Il modello in questo tutorial si basa sul Deep Residual Learning for Image Recognition, che introduce per prima l'architettura di rete residua (ResNet). Il tutorial utilizza la variante a 50 livelli, ResNet-50, e dimostra l'addestramento del modello utilizzando PyTorch/XLA.

Obiettivi

  • Prepara il set di dati.
  • Eseguire il job di addestramento.
  • Verifica i risultati di output.

Costi

In questo documento vengono utilizzati i seguenti componenti fatturabili di Google Cloud:

  • Compute Engine
  • Cloud TPU

Per generare una stima dei costi in base all'utilizzo previsto, utilizza il Calcolatore prezzi. I nuovi utenti di Google Cloud possono essere idonei a una prova senza costi aggiuntivi.

Prima di iniziare

Prima di iniziare questo tutorial, verifica che il tuo progetto Google Cloud sia configurato correttamente.

  1. Accedi al tuo account Google Cloud. Se non conosci Google Cloud, crea un account per valutare le prestazioni dei nostri prodotti in scenari reali. I nuovi clienti ricevono anche 300 $di crediti gratuiti per l'esecuzione, il test e il deployment dei carichi di lavoro.
  2. Nella console di Google Cloud Console, nella pagina del selettore dei progetti, seleziona o crea un progetto Google Cloud.

    Vai al selettore progetti

  3. Assicurati che la fatturazione sia attivata per il tuo progetto Google Cloud.

  4. Nella console di Google Cloud Console, nella pagina del selettore dei progetti, seleziona o crea un progetto Google Cloud.

    Vai al selettore progetti

  5. Assicurati che la fatturazione sia attivata per il tuo progetto Google Cloud.

  6. Questa procedura dettagliata utilizza componenti fatturabili di Google Cloud. Consulta la pagina Prezzi di Cloud TPU per stimare i costi. Assicurati di ripulire le risorse che crei quando hai finito di utilizzarle per evitare addebiti inutili.

Crea una VM TPU

  1. Apri una finestra di Cloud Shell.

    Apri Cloud Shell

  2. Crea una VM TPU

    gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central2-b \
    --project=your-project
    
  3. Connettiti alla VM TPU tramite SSH:

    gcloud compute tpus tpu-vm ssh  your-tpu-name --zone=us-central2-b
    
  4. Installa PyTorch/XLA sulla VM TPU:

    (vm)$ pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
    
  5. Clona il repo github PyTorch/XLA.

    (vm)$ git clone --depth=1 --branch r2.1 https://github.com/pytorch/xla.git
    
  6. Eseguire lo script di addestramento con dati falsi

    (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
    

Se sei in grado di addestrare il modello utilizzando dati falsi, puoi provare ad addestrare il modello su dati reali, come ImageNet. Per istruzioni su come scaricare ImageNet, consulta la pagina sul download di ImageNet. Nel comando script di addestramento, il flag --datadir specifica la posizione del set di dati su cui eseguire l'addestramento. Il comando seguente presuppone che il set di dati ImageNet si trovi in ~/imagenet.

   (vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py  --datadir=~/imagenet --batch_size=256 --num_epochs=1
   

Esegui la pulizia

Per evitare che al tuo Account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

  1. Disconnettiti dalla VM TPU:

    (vm) $ exit
    

    Ora il prompt dovrebbe essere username@projectname, a indicare che ti trovi in Cloud Shell.

  2. Elimina la VM TPU.

    $ gcloud compute tpus tpu-vm delete resnet50-tutorial \
       --zone=us-central2-b
    

Passaggi successivi

Prova i Colab PyTorch: