Executar o código JAX em fatias de TPU

Antes de executar os comandos neste documento, verifique se você seguiu as instruções em Configurar uma conta e um projeto do Cloud TPU.

Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração de TPU. As frações de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.

Criar uma fração da Cloud TPU

  1. Crie algumas variáveis de ambiente:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Descrições das variáveis de ambiente

    PROJECT_ID
    O ID do Google Cloud projeto.
    ACCELERATOR_TYPE
    O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis para cada versão de TPU, consulte Versões de TPU.
    ZONE
    A zona em que você planeja criar a Cloud TPU.
    RUNTIME_VERSION
    A versão do ambiente de execução do Cloud TPU.
    TPU_NAME
    O nome atribuído pelo usuário ao Cloud TPU.
  2. Crie uma fração de TPU usando o comando gcloud. Por exemplo, para criar uma fração v5p-32, use o seguinte comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

Instalar o JAX na fração

Depois de criar a fração de TPU, é necessário instalar o JAX em todos os hosts nessa fração. Para isso, use o comando gcloud compute tpus tpu-vm ssh com os parâmetros --worker=all e --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Executar o código JAX na fração

Para executar o código JAX em uma fatia de TPU, é preciso executar o código em cada host na fatia de TPU. A chamada jax.device_count() para de responder até ser chamada em cada host na fração. O exemplo a seguir ilustra como executar um cálculo JAX em uma fatia de TPU.

Preparar o código

Você precisa da versão gcloud >= 344.0.0 (para o comando scp). Use gcloud --version para verificar a versão de gcloud e execute gcloud components upgrade, se necessário.

Crie um arquivo chamado example.py com o seguinte código:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copiar example.py para todas as VMs de worker de TPU na fração

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Se você não tiver usado o comando scp antes, talvez receba um erro semelhante a este:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver o erro, execute o comando ssh-add conforme exibido na mensagem de erro e execute-o novamente.

Executar o código na fração

Inicie o programa example.py em todas as VMs:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Saída (produzida com uma fatia v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Limpar

Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.

  1. Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:

    (vm)$ exit

    Agora, o prompt precisa ser username@projectname, mostrando que você está no Cloud Shell.

  2. Exclua os recursos do Cloud TPU e do Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. Execute gcloud compute tpus execution-groups list para verificar se os recursos foram excluídos. A exclusão pode levar vários minutos. A saída do comando a seguir não pode incluir nenhum dos recursos criados neste tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}