Executar o código JAX em frações do pod de TPU

Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração do pod de TPU. As frações do pod de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações do pod de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.

Se você quiser usar o NFS ativado para armazenamento de dados, defina o Login do SO para todas as VMs de TPU na fração do pod. Para mais informações, consulte Como usar um NFS para armazenamento de dados.

Criar uma fração do pod de TPU

Antes de executar os comandos neste documento, siga as instruções em Configurar uma conta e um projeto do Cloud TPU. Execute os comandos a seguir na máquina local.

Crie uma fração do pod de TPU usando o comando gcloud. Por exemplo, para criar uma fração de pod v4-32, use o seguinte comando:

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

Instalar o JAX na fração do pod

Depois de criar a fração do pod de TPU, é necessário instalar o JAX em todos os hosts nessa fração. É possível instalar o JAX em todos os hosts com um único comando usando a opção --worker=all:

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Executar o código JAX na fração do pod

Para executar o código JAX em uma fatia do pod de TPU, é preciso executar o código em cada host na fração do pod de TPU. A chamada jax.device_count() para de responder até ser chamada em cada host na fração do pod. O exemplo a seguir ilustra como executar um cálculo simples do JAX em uma fração do Pod de TPU.

Preparar o código

Você precisa da versão 344.0.0 ou mais recente do gcloud (para o comando scp). Use gcloud --version para verificar a versão do gcloud e, se necessário, execute gcloud components upgrade.

Crie um arquivo chamado example.py com o seguinte código:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copiar example.py para todas as VMs de worker da TPU na fração do pod

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

Se você ainda não tiver usado o comando scp, poderá ver um erro semelhante ao seguinte:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver o erro, execute o comando ssh-add conforme exibido na mensagem de erro e execute o comando novamente.

Executar o código na fração do pod

Inicie o programa example.py em todas as VMs:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

Saída (produzida com uma fração do pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Limpar

Quando terminar, você poderá liberar os recursos de VM da TPU usando o comando gcloud:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b