Exécuter du code PyTorch sur des tranches de pod TPU

PyTorch/XLA nécessite que toutes les VM TPU puissent accéder au code et aux données du modèle. Vous pouvez utiliser un script de démarrage pour télécharger le logiciel nécessaire à la distribution des données du modèle à toutes les VM TPU.

Si vous connectez vos VM TPU à un cloud privé virtuel (VPC), vous devez ajouter une règle de pare-feu dans votre projet pour autoriser l'entrée sur les ports 8470 à 8479. Pour en savoir plus sur l'ajout de règles de pare-feu, consultez la page Utiliser des règles de pare-feu.

Configurer votre environnement

  1. Dans Cloud Shell, exécutez la commande suivante pour vous assurer que vous exécutez la version actuelle de gcloud:

    $ gcloud components update
    

    Si vous devez installer gcloud, utilisez la commande suivante:

    $ sudo apt install -y google-cloud-sdk
  2. Créez des variables d'environnement:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Créez la VM TPU.

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Configurer et exécuter le script d'entraînement

  1. Ajoutez votre certificat SSH à votre projet:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Installer PyTorch/XLA sur tous les nœuds de calcul de VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. Cloner XLA sur tous les nœuds de calcul de VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.3 https://github.com/pytorch/xla.git"
    
  4. Exécuter le script d'entraînement sur tous les nœuds de calcul

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    L'entraînement prend environ cinq minutes. Une fois l'opération terminée, un message semblable au suivant doit s'afficher :

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Effectuer un nettoyage

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de Compute Engine:

    (vm)$ exit
    
  2. Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a