Exécuter du code PyTorch sur des tranches TPU

Avant d'exécuter les commandes de ce document, assurez-vous d'avoir suivi les instructions de la section Configurer un compte et un projet Cloud TPU.

Une fois que votre code PyTorch s'exécute sur une seule VM TPU, vous pouvez augmenter la capacité en l'exécutant sur une tranche TPU. Les tranches de TPU sont des cartes de TPU interconnectées sur des connexions réseau haut débit dédiées. Ce document est une introduction à l'exécution de code PyTorch sur des tranches de TPU.

Créer une tranche Cloud TPU

  1. Définissez des variables d'environnement pour faciliter l'utilisation des commandes.

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Descriptions des variables d'environnement

    PROJECT_ID
    L'ID de votre Google Cloud projet.
    ACCELERATOR_TYPE
    Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
    ZONE
    Zone dans laquelle vous prévoyez de créer votre Cloud TPU.
    RUNTIME_VERSION
    Version du logiciel Cloud TPU.
    TPU_NAME
    Nom attribué par l'utilisateur à votre Cloud TPU.
  2. Créez votre VM TPU en exécutant la commande suivante:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --version=${RUNTIME_VERSION}

Installer PyTorch/XLA sur votre tranche

Après avoir créé la tranche TPU, vous devez installer PyTorch sur tous les hôtes de la tranche TPU. Pour ce faire, utilisez la commande gcloud compute tpus tpu-vm ssh avec les paramètres --worker=all et --commamnd.

Si les commandes suivantes échouent en raison d'une erreur de connexion SSH, cela peut être dû au fait que les VM TPU n'ont pas d'adresses IP externes. Pour accéder à une VM TPU sans adresse IP externe, suivez les instructions de la section Se connecter à une VM TPU sans adresse IP publique.

  1. Installez PyTorch/XLA sur tous les nœuds de travail de la VM TPU:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Cloner XLA sur tous les nœuds de travail de VM TPU:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"

Exécuter un script d'entraînement sur votre tranche TPU

Exécutez le script d'entraînement sur tous les nœuds. Le script d'entraînement utilise une stratégie de partitionnement SPMD (Single Program Multiple Data). Pour en savoir plus sur SPMD, consultez le guide de l'utilisateur de SPMD PyTorch/XLA.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --worker=all \
   --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py  \
   --fake_data \
   --model=resnet50  \
   --num_epochs=1 2>&1 | tee ~/logs.txt"

L'entraînement dure environ 15 minutes. Une fois l'opération terminée, un message semblable au suivant doit s'afficher:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%

Effectuer un nettoyage

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de l'instance Cloud TPU, si vous ne l'avez pas déjà fait:

    (vm)$ exit

    Votre invite devrait maintenant être username@projectname, indiquant que vous êtes dans Cloud Shell.

  2. Supprimez vos ressources Cloud TPU.

    $ gcloud compute tpus tpu-vm delete  \
        --zone=${ZONE}
  3. Vérifiez que les ressources ont été supprimées en exécutant la commande gcloud compute tpus tpu-vm list. La suppression peut prendre plusieurs minutes. Le résultat de la commande suivante ne doit inclure aucune des ressources créées dans ce tutoriel:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE}