Exécuter du code PyTorch sur des tranches de pod TPU
PyTorch/XLA nécessite que toutes les VM TPU puissent accéder au code et aux données du modèle. Vous pouvez utiliser un script de démarrage pour télécharger le logiciel nécessaire à la distribution des données du modèle à toutes les VM TPU.
Si vous connectez vos VM TPU à un cloud privé virtuel (VPC), vous devez ajouter une règle de pare-feu dans votre projet pour autoriser l'entrée sur les ports 8470 à 8479. Pour en savoir plus sur l'ajout de règles de pare-feu, consultez la page Utiliser des règles de pare-feu.
Configurer votre environnement
Dans Cloud Shell, exécutez la commande suivante pour vous assurer que vous exécutez la version actuelle de
gcloud
:$ gcloud components update
Si vous devez installer
gcloud
, utilisez la commande suivante:$ sudo apt install -y google-cloud-sdk
Créez des variables d'environnement:
$ export PROJECT_ID=project-id $ export TPU_NAME=tpu-name $ export ZONE=us-central2-b $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-32
Créez la VM TPU.
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}
Configurer et exécuter le script d'entraînement
Ajoutez votre certificat SSH à votre projet:
ssh-add ~/.ssh/google_compute_engine
Installer PyTorch/XLA sur tous les nœuds de calcul de VM TPU
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all --command=" pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
Cloner XLA sur tous les nœuds de calcul de VM TPU
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all --command="git clone -b r2.3 https://github.com/pytorch/xla.git"
Exécuter le script d'entraînement sur tous les nœuds de calcul
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py \ --fake_data \ --model=resnet50 \ --num_epochs=1 2>&1 | tee ~/logs.txt"
L'entraînement prend environ cinq minutes. Une fois l'opération terminée, un message semblable au suivant doit s'afficher :
Epoch 1 test end 23:49:15, Accuracy=100.00 10.164.0.11 [0] Max Accuracy: 100.00%
Effectuer un nettoyage
Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.
Déconnectez-vous de Compute Engine:
(vm)$ exit
Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.
$ gcloud compute tpus tpu-vm list \ --zone europe-west4-a