Executar o código JAX em frações do pod de TPU

Antes de executar os comandos neste documento, verifique se você seguiu as instruções em Configurar uma conta e um projeto do Cloud TPU.

Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração do pod de TPU. As frações do pod de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações do pod de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.

Se você quiser usar o NFS montado para armazenamento de dados, defina o Login do SO para todas as VMs de TPU na fração do pod. Para mais informações, consulte Como usar um NFS para armazenar dados.

Criar uma fração de pod da Cloud TPU

  1. Crie algumas variáveis de ambiente:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Descrições das variáveis de ambiente

    PROJECT_ID
    O ID do Google Cloud projeto.
    ACCELERATOR_TYPE
    O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU.
    ZONE
    A zona em que você planeja criar a Cloud TPU.
    RUNTIME_VERSION
    A versão do ambiente de execução do Cloud TPU.
    TPU_NAME
    O nome atribuído pelo usuário ao Cloud TPU.
  2. Crie uma fração do pod de TPU usando o comando gcloud. Por exemplo, para criar uma fração de pod v5p-32, use o seguinte comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

Instalar o JAX na fração do pod

Depois de criar a fração do pod de TPU, é necessário instalar o JAX em todos os hosts nessa fração. Para isso, use o comando gcloud compute tpus tpu-vm ssh com os parâmetros --worker=all e --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Executar o código JAX na fração do pod

Para executar o código JAX em uma fatia do pod de TPU, é preciso executar o código em cada host na fração do pod de TPU. A chamada jax.device_count() para de responder até ser chamada em cada host no slice do pod. O exemplo a seguir ilustra como executar um cálculo JAX em uma fração do pod de TPU.

Preparar o código

Você precisa da versão gcloud >= 344.0.0 (para o comando scp). Use gcloud --version para verificar a versão de gcloud e execute gcloud components upgrade, se necessário.

Crie um arquivo chamado example.py com o seguinte código:


import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copie example.py para todas as VMs de worker de TPU na fração do pod.

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Se você não tiver usado o comando scp antes, talvez receba um erro semelhante a este:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver o erro, execute o comando ssh-add conforme exibido na mensagem de erro e execute o comando novamente.

Executar o código na fração do pod

Inicie o programa example.py em todas as VMs:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Saída (produzida com uma fração do pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Limpar

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit

    Agora, o prompt precisa ser username@projectname, mostrando que você está no Cloud Shell.

  2. Exclua os recursos do Cloud TPU e do Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. Execute gcloud compute tpus execution-groups list para verificar se os recursos foram excluídos. A exclusão pode levar vários minutos. A saída do comando a seguir não pode incluir nenhum dos recursos criados neste tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}