Exécuter le code PyTorch sur des tranches de pod TPU

PyTorch/XLA nécessite que toutes les VM TPU puissent accéder au code et aux données du modèle. Vous pouvez utiliser un script de démarrage pour télécharger le logiciel nécessaire. pour distribuer les données du modèle à toutes les VM TPU.

Si vous connectez vos VM TPU à un cloud privé virtuel (VPC), vous devez ajouter une règle de pare-feu dans votre projet pour autoriser l'entrée pour les ports 8470 - 8479. Pour en savoir plus sur l'ajout de règles de pare-feu, consultez Utiliser des règles de pare-feu

Configurer votre environnement

  1. Dans Cloud Shell, exécutez la commande suivante pour vérifier que vous êtes exécutant la version actuelle de gcloud:

    $ gcloud components update
    

    Si vous devez installer gcloud, exécutez la commande suivante:

    $ sudo apt install -y google-cloud-sdk
  2. Créez des variables d'environnement:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

Créez la VM TPU.

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Configurer et exécuter le script d'entraînement

  1. Ajoutez votre certificat SSH à votre projet:

    ssh-add ~/.ssh/google_compute_engine
    
  2. Installer PyTorch/XLA sur tous les nœuds de calcul de VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. Cloner XLA sur tous les nœuds de calcul de VM TPU

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.3 https://github.com/pytorch/xla.git"
    
  4. Exécuter le script d'entraînement sur tous les nœuds de calcul

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    L'entraînement prend environ cinq minutes. Une fois l'opération terminée, un message semblable au suivant doit s'afficher :

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Effectuer un nettoyage

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de Compute Engine:

    (vm)$ exit
    
  2. Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a