Guide de l'utilisateur de Cloud TPU PyTorch/XLA

Exécuter des charges de travail de ML avec PyTorch/XLA

Ce guide vous explique comment effectuer un calcul simple sur un TPU v4 à l'aide de PyTorch.

Configuration de base

  1. Créez une VM TPU avec un TPU v4 exécutant l'environnement d'exécution de la VM TPU pour Pytorch 2.0:

      gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-8 \
      --version=tpu-vm-v4-pt-2.0
  2. Connectez-vous à la VM TPU à l'aide de SSH:

      gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-8
  3. Définissez la configuration de l'appareil TPU PJRT ou XRT.

    PJRT

        (vm)$ export PJRT_DEVICE=TPU
     

    XRT

        (vm)$ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
     

  4. Si vous effectuez l'entraînement avec Cloud TPU v4, définissez également la variable d'environnement suivante:

      (vm)$ export TPU_NUM_DEVICES=4

Effectuer un calcul simple

  1. Démarrez l'interpréteur Python sur la VM TPU :

    (vm)$ python3
  2. Importez les packages PyTorch suivants :

    import torch
    import torch_xla.core.xla_model as xm
  3. Saisissez le script suivant :

    dev = xm.xla_device()
    t1 = torch.randn(3,3,device=dev)
    t2 = torch.randn(3,3,device=dev)
    print(t1 + t2)

    Le résultat suivant s'affiche :

    tensor([[-0.2121,  1.5589, -0.6951],
           [-0.7886, -0.2022,  0.9242],
           [ 0.8555, -1.8698,  1.4333]], device='xla:1')
    

Exécuter Resnet sur un TPU individuel

À ce stade, vous pouvez exécuter n'importe quel code PyTorch/XLA. Par exemple, vous pouvez exécuter un modèle ResNet avec des données fictives :

(vm)$ git clone --recursive https://github.com/pytorch/xla.git
(vm)$ python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1

L'exemple ResNet est entraîné pendant une époque et prend environ sept minutes. La commande renvoie un résultat semblable au suivant :

Epoch 1 test end 20:57:52, Accuracy=100.00 Max Accuracy: 100.00%

Une fois l'entraînement ResNet terminé, supprimez la VM TPU.

(vm)$ exit
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone=zone

La suppression peut prendre plusieurs minutes. Vérifiez que les ressources ont été supprimées en exécutant la commande gcloud compute tpus list --zone=${ZONE}.

Rubriques avancées

Pour les modèles comportant des allocations importantes et fréquentes, tcmalloc améliore les performances par rapport à la fonction d'exécution C/C++ malloc. La valeur par défaut de malloc utilisée sur la VM TPU est tcmalloc. Vous pouvez forcer le logiciel de la VM TPU à utiliser la version standard de malloc en désactivant la variable d'environnement LD_PRELOAD:

   (vm)$ unset LD_PRELOAD

Dans les exemples précédents (le calcul simple et ResNet50), le programme PyTorch/XLA démarre le serveur XRT local dans le même processus que l'interpréteur Python. Vous pouvez également choisir de démarrer le service local XRT dans un processus distinct :

(vm)$ python3 -m torch_xla.core.xrt_run_server --port 51011 --restart

L'avantage de cette approche est que le cache de compilation est conservé d'un cycle d'entraînement à un autre. Lorsque vous exécutez le serveur XLA dans un processus distinct, les informations de journalisation côté serveur sont écrites dans /tmp/xrt_server_log.

(vm)$ ls /tmp/xrt_server_log/
server_20210401-031010.log

Profilage des performances des VM TPU

Pour en savoir plus sur le profilage de vos modèles sur une VM TPU, consultez la page Profilage des performances PyTorch XLA.

Exemples de pods TPU PyTorch/XLA

Consultez la page Pod de VM TPU PyTorch pour obtenir des informations de configuration et des exemples d'exécution de PyTorch/XLA sur un pod de VM TPU.

Docker sur une VM TPU

Cette section explique comment exécuter Docker sur une VM TPU avec PyTorch/XLA préinstallé.

Images Docker disponibles

Vous pouvez vous reporter au fichier README de GitHub pour trouver toutes les images Docker de VM TPU disponibles.

Exécuter des images Docker sur une VM TPU

(tpuvm): sudo docker pull gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
(tpuvm): sudo docker run --privileged  --shm-size 16G --name tpuvm_docker -it -d  gcr.io/tpu-pytorch/xla:nightly_3.8_tpuvm
(tpuvm): sudo docker exec --privileged -it tpuvm_docker /bin/bash
(pytorch) root:/#

Valider libtpu

Pour vérifier que libtpu est installé, exécutez la commande suivante:

(pytorch) root:/# ls /root/anaconda3/envs/pytorch/lib/python3.8/site-packages/ | grep libtpu
Vous devriez obtenir un résultat semblable à celui-ci:
libtpu
libtpu_nightly-0.1.dev20220518.dist-info

Si aucun résultat ne s'affiche, vous pouvez installer manuellement la bibliothèque libtpu correspondant à l'aide de la commande suivante:

(pytorch) root:/# pip install torch_xla[tpuvm]

Vérifier tcmalloc

tcmalloc est le malloc utilisé par défaut sur la VM TPU. Pour en savoir plus, consultez cette section. Cette bibliothèque doit être préinstallée sur les images Docker de VM TPU plus récentes, mais il est toujours préférable de la vérifier manuellement. Vous pouvez exécuter la commande suivante pour vérifier que la bibliothèque est installée.

(pytorch) root:/# echo $LD_PRELOAD
Vous devriez obtenir un résultat semblable à celui-ci:
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4

Si LD_PRELOAD n'est pas défini, vous pouvez exécuter manuellement:

(pytorch) root:/# sudo apt-get install -y google-perftools
(pytorch) root:/# export LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4"

Valider l'appareil

Vous pouvez vérifier que l'appareil TPU VM est disponible en exécutant la commande suivante:

(pytorch) root:/# ls /dev | grep accel
Vous devriez obtenir les résultats suivants.
accel0
accel1
accel2
accel3

Si aucun résultat ne s'affiche, vous n'avez probablement pas démarré le conteneur avec l'option --privileged.

Exécuter un modèle

Vous pouvez vérifier si l'appareil de VM TPU est disponible en exécutant la commande suivante:

(pytorch) root:/# export XRT_TPU_CONFIG="localservice;0;localhost:51011"
(pytorch) root:/# python3 pytorch/xla/test/test_train_mp_imagenet.py --fake_data --num_epochs 1