Executar o código JAX em frações do pod de TPU

Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração do pod de TPU. As frações do pod de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações do pod de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.

Se você quiser usar NFS ativado para armazenamento de dados, defina o login do SO para todas as VMs de TPU na fração do pod. Para mais informações, consulte Como usar um NFS para armazenamento de dados.

Criar uma fração do pod de TPU

Antes de executar os comandos neste documento, siga as instruções em Configurar uma conta e um projeto da Cloud TPU. Execute os comandos a seguir na máquina local.

Crie uma fração do pod de TPU usando o comando gcloud. Por exemplo, para criar uma fração do pod v2-32, use o seguinte comando:

$ gcloud compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version tpu-vm-base

Instalar o JAX na fração do pod

Depois de criar a fração do pod de TPU, é necessário instalar o JAX em todos os hosts nessa fração. É possível instalar o JAX em todos os hosts com um único comando usando a opção --worker=all:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Executar o código JAX na fração do pod

Para executar o código JAX em uma fatia do pod de TPU, é preciso executar o código em cada host na fração do pod de TPU. A chamada jax.device_count() para de responder até ser chamada em cada host na fração do pod. Veja no exemplo a seguir como executar um cálculo JAX simples em uma fração do pod da TPU.

Preparar o código

Você precisa da versão gcloud >= 344.0.0 (para o comando scp). Use gcloud --version para verificar a versão de gcloud e execute gcloud components upgrade, se necessário.

Grave example.py na máquina local:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copiar example.py para todas as VMs na fração de pod

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all --zone=europe-west4-a

Se você nunca usou o comando scp, talvez veja um erro semelhante ao seguinte:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver o erro, execute o comando ssh-add conforme exibido na mensagem de erro e execute o comando novamente.

Executar o código na fração do pod

Inicie o programa example.py em todas as VMs:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a --worker=all --command "python3 example.py"

Saída (produzida com uma fração do pod v2-32):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

Limpar

Quando terminar, você poderá liberar os recursos de VM da TPU usando o comando gcloud:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a