Entrena modelos de PyTorch en pods de Cloud TPU


En este instructivo, se muestra cómo escalar de forma vertical el entrenamiento de tu modelo desde una sola Cloud TPU (v2-8 o v3-8) a un pod de Cloud TPU con la configuración de nodo TPU. Los aceleradores de Cloud TPU en un pod de TPU están conectados por interconexiones de ancho de banda alto, por lo que son eficientes para escalar verticalmente los trabajos de entrenamiento.

Para obtener más información sobre la oferta de pods de Cloud TPU, consulta la página de productos de Cloud TPU o esta presentación de Cloud TPU.

En el siguiente diagrama, se proporciona una descripción general de la configuración del clúster distribuido. Un grupo de instancias de VM está conectado a un pod de TPU. Se necesita una VM para cada grupo de 8 núcleos de TPU. Las VM envían datos a los núcleos de la TPU y todo el entrenamiento se produce en el pod de TPU.

imagen

Objetivos

  • Configura un grupo de instancias de Compute Engine y el pod de Cloud TPU para entrenar con PyTorch/XLA
  • Ejecuta el entrenamiento PyTorch/XLA en un pod de Cloud TPU.

Antes de comenzar

Antes de comenzar el entrenamiento distribuido en los pods de Cloud TPU, verifica que tu modelo entrene en un solo dispositivo de Cloud TPU v2-8 o v3-8. Si tu modelo tiene problemas de rendimiento significativos en un solo dispositivo, consulta las guías sobre prácticas recomendadas y solución de problemas.

Una vez que tu único dispositivo de TPU se haya entrenado con éxito, realiza los siguientes pasos para configurar y entrenar en un pod de Cloud TPU:

  1. Configura el comando gcloud.

  2. Captura una imagen de disco de VM en una imagen de VM (opcional).

  3. Crea una plantilla de instancias a partir de una imagen de VM.

  4. Crea un grupo de instancias a partir de tu plantilla de instancias.

  5. Establece una conexión SSH en tu VM de Compute Engine.

  6. Verifica las reglas del firewall para permitir la comunicación entre VM.

  7. Crea un pod de Cloud TPU.

  8. Ejecuta el entrenamiento distribuido en el pod.

  9. Realiza una limpieza.

Configura el comando gcloud

Configura tu proyecto de Google Cloud con gcloud:

Crea una variable para el ID de tu proyecto.

export PROJECT_ID=project-id

Configura el ID del proyecto como el proyecto predeterminado en gcloud

gcloud config set project ${PROJECT_ID}

La primera vez que ejecutes este comando en una VM de Cloud Shell nueva, se mostrará la página Authorize Cloud Shell. Haz clic en Authorize en la parte inferior de la página para permitir que gcloud realice llamadas a la API con tus credenciales.

Configura la zona predeterminada con gcloud:

gcloud config set compute/zone europe-west4-a

Captura una imagen de disco de VM (opcional)

Puedes usar la imagen de disco de la VM que usaste para entrenar la única TPU (que ya tiene el conjunto de datos, los paquetes instalados, etc.). Antes de crear una imagen, detén la VM con el comando gcloud:

gcloud compute instances stop vm-name

A continuación, crea una imagen de VM con el comando gcloud:

gcloud compute images create image-name  \
    --source-disk instance-name \
    --source-disk-zone europe-west4-a \
    --family=torch-xla \
    --storage-location europe-west4

Crea una plantilla de instancias a partir de una imagen de VM

Crea una plantilla de instancias predeterminada. Cuando estás creando una plantilla de instancias, puedes usar la imagen de VM que creaste en el paso anterior O puedes usar la imagen pública de PyTorch/XLA que proporciona Google. Para crear una plantilla de instancias, usa el comando gcloud:

gcloud compute instance-templates create instance-template-name \
    --machine-type n1-standard-16 \
    --image-project=${PROJECT_ID} \
    --image=image-name \
    --scopes=https://www.googleapis.com/auth/cloud-platform

Crea un grupo de instancias a partir de tu plantilla de instancias

gcloud compute instance-groups managed create instance-group-name \
    --size 4 \
    --template template-name \
    --zone europe-west4-a

Establece una conexión SSH en tu VM de Compute Engine

Después de crear tu grupo de instancias, establece una conexión SSH en una de las instancias (VM) de tu grupo. Usa el siguiente comando para enumerar todas las instancias en tu grupo de instancias y el comando gcloud:

gcloud compute instance-groups list-instances instance-group-name

Establece una conexión SSH con una de las instancias enumeradas desde el comando list-instances.

gcloud compute ssh instance-name --zone=europe-west4-a

Verifica que las VM de tu grupo de instancias puedan comunicarse entre sí

Usa el comando nmap para verificar que las VM de tu grupo de instancias puedan comunicarse entre sí. Ejecuta el comando nmap desde la VM a la que estás conectado y reemplaza instance-name por el nombre de instancia de otra VM en tu grupo de instancias.

(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown

Mientras el campo STATE no diga filtrado, las reglas del firewall están configuradas de forma correcta.

Crea un pod de Cloud TPU

gcloud compute tpus create tpu-name \
    --zone=europe-west4-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=pytorch-1.13

Ejecuta el entrenamiento distribuido en el pod

  1. En la ventana de sesión de VM, exporta el nombre de Cloud TPU y activa el entorno conda.

    (vm)$ export TPU_NAME=tpu-name
    (vm)$ conda activate torch-xla-1.13
    
  2. Ejecuta la siguiente secuencia de comandos de entrenamiento:

    (torch-xla-1.13)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.13 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- python /usr/share/torch-xla-1.13/pytorch/xla/test/test_train_mp_imagenet.py \
          --fake_data
    

Una vez que ejecutes el comando anterior, deberías ver un resultado similar al siguiente (ten en cuenta que esto es mediante --fake_data). El entrenamiento tarda alrededor de media hora en un pod de TPU v3-32.

2020-08-06 02:38:29  [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29  [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57

Realiza una limpieza

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

  1. Desconéctate de la VM de Compute Engine:

    (vm)$ exit
    
  2. Borra tu grupo de instancias:

    gcloud compute instance-groups managed delete instance-group-name
    
  3. Borra el pod de TPU:

    gcloud compute tpus delete ${TPU_NAME} --zone=europe-west4-a
    
  4. Borra la plantilla del grupo de instancias:

    gcloud compute instance-templates delete instance-template-name
    
  5. Borra la imagen de disco de VM (opcional):

    gcloud compute images delete image-name
    

¿Qué sigue?

Prueba los siguientes colaboradores de PyTorch: