Questo tutorial mostra come addestrare il modello ResNet-50 su un dispositivo Cloud TPU con PyTorch. Puoi applicare lo stesso pattern ad altri modelli di classificazione delle immagini ottimizzati per TPU che utilizzano PyTorch e il set di dati ImageNet.
Il modello in questo tutorial si basa sul deep residual learning per il riconoscimento delle immagini, che introduce per la prima volta l'architettura della rete di residui (ResNet). Il tutorial utilizza la variante con 50 livelli, ResNet-50, e mostra l'addestramento del modello utilizzando PyTorch/XLA.
Obiettivi
- Prepara il set di dati.
- Esegui il job di addestramento.
- Verifica i risultati dell'output.
Costi
In questo documento utilizzi i seguenti componenti fatturabili di Google Cloud:
- Compute Engine
- Cloud TPU
Per generare una stima dei costi in base all'utilizzo previsto,
utilizza il Calcolatore prezzi.
Prima di iniziare
Prima di iniziare questo tutorial, controlla che il progetto Google Cloud sia configurato correttamente.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
-
In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
-
Make sure that billing is enabled for your Google Cloud project.
Questa procedura dettagliata utilizza i componenti fatturabili di Google Cloud. Consulta la pagina Prezzi per Cloud TPU per stimare i costi. Una volta terminato il loro utilizzo, assicurati di eseguire la pulizia delle risorse create per evitare addebiti superflui.
Crea una VM TPU
Apri una finestra di Cloud Shell.
Crea una VM TPU
gcloud compute tpus tpu-vm create your-tpu-name \ --accelerator-type=v4-8 \ --version=tpu-ubuntu2204-base \ --zone=us-central2-b \ --project=your-project
Connettiti alla VM TPU tramite SSH:
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central2-b
Installa PyTorch/XLA sulla VM TPU:
(vm)$ pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
Clona il repository GitHub di PyTorch/XLA
(vm)$ git clone --depth=1 --branch r2.5 https://github.com/pytorch/xla.git
Esegui lo script di addestramento con dati falsi
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Se riesci ad addestrare il modello utilizzando dati falsi, puoi provare ad addestrare il modello su dati reali, come ImageNet. Per istruzioni su come scaricare ImageNet, consulta
Download di ImageNet. Nel comando dello script di addestramento,
--datadir
specifica la posizione del set di dati su cui eseguire l'addestramento.
Il seguente comando presuppone che il set di dati ImageNet si trovi in ~/imagenet
.
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --datadir=~/imagenet --batch_size=256 --num_epochs=1
Esegui la pulizia
Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.
Disconnettiti dalla VM TPU:
(vm) $ exit
Il tuo prompt dovrebbe ora essere
username@projectname
, a indicare che ti trovi in Cloud Shell.Elimina la VM TPU.
$ gcloud compute tpus tpu-vm delete resnet50-tutorial \ --zone=us-central2-b
Passaggi successivi
- Addestramento dei modelli di diffusione con PyTorch
- Risoluzione dei problemi di PyTorch su TPU
- Documentazione di Pytorch/XLA