Como treinar modelos PyTorch em pods do Cloud TPU

Neste tutorial, você verá como escalonar verticalmente o treinamento do modelo a partir de uma única Cloud TPU (v2-8 ou v3-8) para um pod do Cloud TPU. Os aceleradores do Cloud TPU em um pod de TPU são conectados por interconexões de alta largura de banda. Isso os torna eficientes para escalonar verticalmente jobs de treinamento.

Para mais informações sobre as ofertas de pods da Cloud TPU, consulte a página do produto ou a apresentação da Cloud TPU (em inglês).

O diagrama a seguir fornece uma visão geral da configuração do cluster distribuído. Um grupo de instâncias de VMs está conectado a um Pod de TPU. É necessária uma VM para cada grupo de oito núcleos de TPU. As VMs alimentam os dados para os núcleos da TPU e todos os treinamentos ocorrem no Pod de TPU.

image

Objetivos

  • Configurar um grupo de instâncias do Compute Engine e um pod da Cloud TPU para treinamento com PyTorch/XLA
  • Executar treinamento PyTorch/XLA em um pod da Cloud TPU

Antes de começar

Antes de iniciar o treinamento distribuído nos Pods do Cloud TPU, verifique se o modelo treina em um único dispositivo do Cloud TPU v2-8 ou v3-8. Caso o modelo tenha problemas significativos de desempenho em um único dispositivo, consulte os guias de práticas recomendadas e de solução de problemas (links em inglês).

Depois de treinar o único dispositivo TPU, execute os passos a seguir para configurar e treinar em um Pod do Cloud TPU:

  1. Configure o comando gcloud.

  2. [Opcional] Capture uma imagem de disco da VM em uma imagem de VM.

  3. Crie um modelo de instância a partir de uma imagem de VM.

  4. Crie um grupo de instâncias a partir do seu modelo de instância.

  5. SSH na VM do Compute Engine

  6. Verifique as regras de firewall para permitir a comunicação entre VMs..

  7. Crie um pod da Cloud TPU.

  8. Execute treinamento distribuído no pod.

  9. Faça a limpeza.

Configurar o comando gcloud

Configure seu projeto do GCP com gcloud:

Crie uma variável para o ID do seu projeto.

export PROJECT_ID=project-id

Defina o ID do projeto como o padrão em gcloud

gcloud config set project ${PROJECT_ID}

Na primeira vez que você executar esse comando em uma nova VM do Cloud Shell, uma página Authorize Cloud Shell será exibida. Clique em Authorize na parte inferior da página para permitir que gcloud faça chamadas de API do GCP com suas credenciais.

Configure a zona padrão com gcloud:

gcloud config set compute/zone us-central1-a

[Opcional] Capturar uma imagem do disco da VM

É possível usar a imagem do disco da VM usada para treinar a única TPU (que já tem o conjunto de dados, os pacotes instalados etc.) Antes de criar uma imagem, interrompa a VM usando o comando gcloud:

gcloud compute instances stop vm-name

Em seguida, crie uma imagem de VM usando o comando gcloud:

gcloud compute images create image-name  \
    --source-disk instance-name \
    --source-disk-zone us-central1-a \
    --family=torch-xla \
    --storage-location us-central1

Crie um modelo de instância a partir de uma imagem de VM

Criar um modelo de instância padrão. Ao criar um modelo de instância, use a imagem de VM criada na etapa acima OU use a imagem pública PyTorch/XLA fornecida pelo Google. Para criar um modelo de instância, use o comando gcloud:

gcloud compute instance-templates create instance-template-name \
    --machine-type n1-standard-16 \
    --image-project=${PROJECT_ID} \
    --image=image-name \
    --scopes=https://www.googleapis.com/auth/cloud-platform

Crie um grupo de instâncias a partir do modelo de instância

gcloud compute instance-groups managed create instance-group-name \
    --size 4 \
    --template template-name \
    --zone us-central1-a

SSH na VM do Compute Engine

Depois de criar o grupo de instâncias, use o SSH em uma das instâncias (VMs) no grupo. Use o comando a seguir para listar todas as instâncias da instância agrupando o comando gcloud:

gcloud compute instance-groups list-instances instance-group-name

Conecte-se a uma das instâncias listadas no comando list-instances.

gcloud compute ssh instance-name --zone=us-central1-a

Verificar se as VMs no grupo de instâncias podem se comunicar umas com as outras

Use o comando nmap para verificar se as VMs no grupo de instâncias podem se comunicar umas com as outras. Execute o comando nmap na VM a que você está conectado, substituindo instance-name pelo nome da instância de outra VM no grupo de instâncias.

(vm)$ nmap -Pn -p 8477 instance-name
Starting Nmap 7.40 ( https://nmap.org ) at 2019-10-02 21:35 UTC
Nmap scan report for pytorch-20190923-n4tx.c.jysohntpu.internal (10.164.0.3)
Host is up (0.00034s latency).
PORT     STATE  SERVICE
8477/tcp closed unknown

Contanto que o campo ESTADO não exiba filtrado, as regras de firewall estão configuradas corretamente.

Criar um pod da Cloud TPU

gcloud compute tpus create tpu-name \
    --zone=us-central1-a \
    --network=default \
    --accelerator-type=v2-32 \
    --version=1.7

Executar treinamentos distribuídos no pod

  1. Na janela de sessão da VM, exporte o nome da Cloud TPU e ative o ambiente conda.

    (vm)$ export TPU_NAME=tpu-name
    (vm)$ conda activate torch-xla-1.7
    
  2. Execute o script de treinamento:

    (torch-xla-1.7)$ python -m torch_xla.distributed.xla_dist \
          --tpu=$TPU_NAME \
          --conda-env=torch-xla-1.7 \
          --env XLA_USE_BF16=1 \
          --env ANY_OTHER=ENV_VAR \
          -- python /usr/share/torch-xla-1.7/pytorch/xla/test/test_train_mp_imagenet.py \
          --fake_data
    

Depois de executar o comando acima, haverá saída semelhante a esta. Observe que o --fake_data está sendo usado. O treinamento leva cerca de meia hora em um pod de TPU v3-32.

2020-08-06 02:38:29  [] Command to distribute: "python" "/usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py" "--fake_data"
2020-08-06 02:38:29  [] Cluster configuration: {client_workers: [{10.164.0.43, n1-standard-96, europe-west4-a, my-instance-group-hm88}, {10.164.0.109, n1-standard-96, europe-west4-a, my-instance-group-n3q2}, {10.164.0.46, n1-standard-96, europe-west4-a, my-instance-group-s0xl}, {10.164.0.49, n1-standard-96, europe-west4-a, my-instance-group-zp14}], service_workers: [{10.131.144.61, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.59, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.58, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}, {10.131.144.60, 8470, v3-32, europe-west4-a, pytorch-nightly, my-tpu-slice}]}
2020-08-06 02:38:31 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:31 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2757      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:34 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:34 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2623      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2583      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2530      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:37 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:37 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2317      0 --:--:-- --:--:-- --:--:--  2375
2020-08-06 02:38:40 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
2020-08-06 02:38:40 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2748      0 --:--:-- --:--:-- --:--:--  3166
100    19  100    19    0     0   2584      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:40 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:40 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2495      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.49 [3]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.49 [3]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2654      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:43 10.164.0.43 [0]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.43 [0]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2784      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.46 [2]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.46 [2]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2691      0 --:--:-- --:--:-- --:--:--  3166
2020-08-06 02:38:43 10.164.0.109 [1]   % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
2020-08-06 02:38:43 10.164.0.109 [1]                                  Dload  Upload   Total   Spent    Left  Speed
100    19  100    19    0     0   2589      0 --:--:-- --:--:-- --:--:--  2714
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/14 Epoch=1 Step=0 Loss=6.87500 Rate=258.47 GlobalRate=258.47 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/15 Epoch=1 Step=0 Loss=6.87500 Rate=149.45 GlobalRate=149.45 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] Epoch 1 train begin 02:38:52
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:1/0 Epoch=1 Step=0 Loss=6.87500 Rate=25.72 GlobalRate=25.72 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/6 Epoch=1 Step=0 Loss=6.87500 Rate=89.01 GlobalRate=89.01 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/1 Epoch=1 Step=0 Loss=6.87500 Rate=64.15 GlobalRate=64.15 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/2 Epoch=1 Step=0 Loss=6.87500 Rate=93.19 GlobalRate=93.19 Time=02:38:57
2020-08-06 02:38:57 10.164.0.43 [0] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.43 [0] | Training Device=xla:0/7 Epoch=1 Step=0 Loss=6.87500 Rate=58.78 GlobalRate=58.78 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] Epoch 1 train begin 02:38:56
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:1/8 Epoch=1 Step=0 Loss=6.87500 Rate=100.43 GlobalRate=100.43 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/13 Epoch=1 Step=0 Loss=6.87500 Rate=66.83 GlobalRate=66.83 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/11 Epoch=1 Step=0 Loss=6.87500 Rate=64.28 GlobalRate=64.28 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/10 Epoch=1 Step=0 Loss=6.87500 Rate=73.17 GlobalRate=73.17 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/9 Epoch=1 Step=0 Loss=6.87500 Rate=27.29 GlobalRate=27.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.109 [1] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.109 [1] | Training Device=xla:0/12 Epoch=1 Step=0 Loss=6.87500 Rate=110.29 GlobalRate=110.29 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/20 Epoch=1 Step=0 Loss=6.87500 Rate=100.85 GlobalRate=100.85 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/22 Epoch=1 Step=0 Loss=6.87500 Rate=93.52 GlobalRate=93.52 Time=02:38:57
2020-08-06 02:38:57 10.164.0.46 [2] ==> Preparing data..
2020-08-06 02:38:57 10.164.0.46 [2] | Training Device=xla:0/23 Epoch=1 Step=0 Loss=6.87500 Rate=165.86 GlobalRate=165.86 Time=02:38:57

Limpar

Para evitar cobranças dos recursos usados neste tutorial na conta do Google Cloud Platform:

  1. Encerre a conexão com a VM do Compute Engine:

    (vm)$ exit
    
  2. Exclua o grupo de instâncias:

    gcloud compute instance-groups managed delete instance-group-name
    
  3. Exclua o Pod de TPU:

    gcloud compute tpus delete ${TPU_NAME} --zone=us-central1-a
    

A seguir

Teste as colabs do PyTorch: