Execute código JAX em fatias de TPUs

Antes de executar os comandos neste documento, certifique-se de que seguiu as instruções em Configure uma conta e um projeto do Cloud TPU.

Depois de ter o código JAX em execução numa única placa de TPU, pode dimensionar o código executando-o numa divisão de TPU. As fatias de TPU são várias placas de TPU ligadas entre si através de ligações de rede de alta velocidade dedicadas. Este documento é uma introdução à execução de código JAX em fatias de TPU. Para obter informações mais detalhadas, consulte o artigo Usar o JAX em ambientes com vários anfitriões e vários processos.

Crie uma fatia da Cloud TPU

  1. Crie algumas variáveis de ambiente:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5litepod-32
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Descrições das variáveis de ambiente

    Variável Descrição
    PROJECT_ID O seu Google Cloud ID do projeto. Use um projeto existente ou crie um novo.
    TPU_NAME O nome da TPU.
    ZONE A zona na qual criar a VM da TPU. Para mais informações sobre as zonas suportadas, consulte o artigo Regiões e zonas de TPUs.
    ACCELERATOR_TYPE O tipo de acelerador especifica a versão e o tamanho do Cloud TPU que quer criar. Para mais informações sobre os tipos de aceleradores suportados para cada versão da TPU, consulte o artigo Versões da TPU.
    RUNTIME_VERSION A versão do software do Cloud TPU.

  2. Crie uma fatia de TPU com o comando gcloud. Por exemplo, para criar uma fatia v5litepod-32, use o seguinte comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --accelerator-type=${ACCELERATOR_TYPE}  \
        --version=${RUNTIME_VERSION} 

Instale o JAX na sua fatia

Depois de criar a fatia de TPU, tem de instalar o JAX em todos os anfitriões na fatia de TPU. Pode fazê-lo através do comando gcloud compute tpus tpu-vm ssh com os parâmetros --worker=all e --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Execute código JAX na fatia

Para executar código JAX numa fatia de TPU, tem de executar o código em cada anfitrião na fatia de TPU. A chamada jax.device_count() deixa de responder até ser chamada em cada anfitrião na fatia. O exemplo seguinte ilustra como executar um cálculo JAX numa fatia de TPU.

Prepare o código

Precisa da versão gcloud >= 344.0.0 (para o comando scp). Use gcloud --version para verificar a versão do gcloud e execute gcloud components upgrade, se necessário.

Crie um ficheiro denominado example.py com o seguinte código:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copie example.py para todas as VMs de trabalho da TPU na fatia

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Se não tiver usado o comando scp anteriormente, pode ver um erro semelhante ao seguinte:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver o erro, execute o comando ssh-add, conforme apresentado na mensagem de erro, e volte a executar o comando.

Execute o código na fatia

Inicie o programa example.py em todas as VMs:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Saída (produzida com uma fatia v5litepod-32):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Limpar

Quando terminar de usar a VM de TPU, siga estes passos para limpar os recursos.

  1. Elimine os recursos do Cloud TPU e do Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  2. Verifique se os recursos foram eliminados executando gcloud compute tpus execution-groups list. A eliminação pode demorar vários minutos. O resultado do seguinte comando não deve incluir nenhum dos recursos criados neste tutorial:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}