Entraîner des modèles PyTorch sur les pods Cloud TPU

Ce tutoriel montre comment faire évoluer l'entraînement de votre modèle depuis un Cloud TPU simple (v2-8 ou v3-8) vers un pod Cloud TPU. Les accélérateurs de Cloud TPU dans un pod TPU sont reliés entre eux par des interconnexions à bande passante très large, ce qui les rend très efficaces pour faire évoluer des tâches d'entraînement.

De plus amples informations sur les offres Cloud TPU sont disponibles sur la page produit Cloud TPU ainsi que dans cette présentation de Cloud TPU.

Le schéma suivant présente une vue d'ensemble de la configuration du cluster distribué. Un groupe d'instances de VM est connecté à un pod TPU. Une VM est requise pour chaque groupe de 8 cœurs de TPU. Les VM alimentent les cœurs de TPU et toutes les tâches d'entraînement s'exécutent sur le pod TPU.

image

Objectifs

  • Configurer un groupe d'instances Compute Engine et un pod Cloud TPU pour s'entraîner avec PyTorch/XLA
  • Exécuter une tâche d'entraînement PyTorch/XLA sur un boîtier Cloud TPU

Avant de commencer

Avant de commencer un entraînement distribué sur des pods Cloud TPU, vérifiez que votre modèle peut être entraîné sur un seul appareil Cloud TPU v2-8 ou v3-8. Si votre modèle présente des problèmes de performances importants sur un seul appareil, consultez les guides de bonnes pratiques et de résolution des problèmes.

Une fois que votre appareil TPU est correctement entraîné, effectuez les étapes suivantes pour configurer et entraîner sur un pod Cloud TPU :

  1. Configurez la commande gcloud.

  2. [Facultatif] Capturez une image du disque de la VM dans une image de VM.

  3. Créez un modèle d'instance à partir d'une image de VM.

  4. Créez un groupe d'instances à partir de votre modèle d'instance.

  5. Connection SSH à votre VM Compute Engine

  6. Vérifier les règles de pare-feu pour autoriser les communications entre VM.

  7. Créer un pod Cloud TPU.

  8. Exécuter un entraînement distribué sur le pod.

  9. Nettoyer

Configurer la commande gcloud

Configurez votre projet GCP avec gcloud :

Créez une variable pour l'ID de votre projet.

export PROJECT_ID=project-id

Définissez votre ID de projet comme projet par défaut dans gcloud.

gcloud config set project ${PROJECT_ID}

Configurez la zone par défaut avec gcloud :

gcloud config set compute/zone us-central1-a

[Facultatif] Capturer une image du disque de la VM

Vous pouvez utiliser l'image disque de la VM utilisée pour entraîner le TPU seul. Elle contient déjà l'ensemble de données, les packages installés, etc. Avant de créer une image, arrêtez la VM à l'aide de la commande gcloud :

gcloud compute instances stop vm-name

Ensuite, créez une image de la VM à l'aide de la commande gcloud :

gcloud compute images create image-name  \
    --source-disk instance-name \
    --source-disk-zone us-central1-a \
    --family=torch-xla \
    --storage-location us-central1

Créer un modèle d'instance à partir d'une image de VM

Créez un modèle d'instance par défaut. Lorsque vous créez un modèle d'instance, vous pouvez utiliser l'image de la VM créée à l'étape ci-dessus OU utiliser l'image publique PyTorch/XLA fournie par Google. Pour créer un modèle d'instance, exécutez la commande gcloud :

gcloud compute instance-templates create instance-template-name \
    --machine-type n1-standard-16 \
    --image-project=${PROJECT_ID} \
    --image=image-name \
    --scopes=https://www.googleapis.com/auth/cloud-platform

Créer un groupe d'instances à partir de votre modèle d'instance

gcloud compute instance-groups managed create instance-group-name \
    --size 4 \
    --template template-name \
    --zone us-central1-a

Connection SSH à votre VM Compute Engine

Après avoir créé votre groupe d'instances, connectez-vous en SSH à l'une des instances (VM) de votre groupe d'instances. Exécutez la commande suivante pour répertorier toutes les instances de votre groupe d'instances dans la commande gcloud :

gcloud compute instance-groups list-instances instance-group-name

Connectez-vous en SSH à l'une des instances répertoriées à l'aide de la commande list-instances.

gcloud compute ssh instance-name --zone=us-central1-a

Vérifier que les VM de votre groupe d'instances peuvent communiquer entre elles

Utilisez la commande nmap pour vérifier que les VM de votre groupe d'instances peuvent communiquer entre elles. Exécutez la commande nmap à partir de la VM à laquelle vous êtes connecté en remplaçant instance-name par le nom d'instance d'une autre VM de votre groupe d'instances.

(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown

Les règles de pare-feu sont correctement définies si le champ STATE n'indique pas filtered (filtré).

Créer un pod Cloud TPU

gcloud compute tpus create tpu-name \
    --zone=us-central1-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=1.6

Exécuter un entraînement distribué sur le pod

  1. À partir de la fenêtre de session de VM, exportez le nom Cloud TPU et activez l'environnement Conda.

    (vm)$ export TPU_NAME=tpu-name
    (vm)$ conda activate torch-xla-1.6
    
  2. Exécutez le script d'entraînement :

    (torch-xla-1.6)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.6 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- python /usr/share/torch-xla-1.6/pytorch/xla/test/test_train_mp_imagenet.py \
          --fake_data
    

Une fois la commande ci-dessus exécutée, vous devriez voir une sortie semblable à celle-ci (notez que cet exemple utilise --fake_data). L'entraînement dure environ une demi-heure sur un pod TPU v3-32.

2020-08-06 02:38:29  [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29  [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57

Nettoyer

Pour éviter que les ressources utilisées dans ce tutoriel soient facturées sur votre compte Google Cloud Platform :

  1. Déconnectez-vous de la VM Compute Engine :

    (vm)$ exit
    
  2. Supprimez votre groupe d'instances :

    gcloud compute instance-groups managed delete instance-group-name
    
  3. Supprimez votre pod TPU :

    gcloud compute tpus delete ${TPU_NAME} --zone=us-central1-a
    

Étape suivante

Essayez les colabs PyTorch :