Executar o código JAX em frações do pod de TPU
Antes de executar os comandos neste documento, verifique se você seguiu as instruções em Configurar uma conta e um projeto do Cloud TPU.
Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração do pod de TPU. As frações do pod de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações do pod de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.
Se você quiser usar o NFS montado para armazenamento de dados, defina o Login do SO para todas as VMs de TPU na fração do pod. Para mais informações, consulte Como usar um NFS para armazenar dados.Criar uma fração de pod da Cloud TPU
Crie algumas variáveis de ambiente:
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-32 export ZONE=europe-west4-b export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
Descrições das variáveis de ambiente
PROJECT_ID
- O ID do Google Cloud projeto.
ACCELERATOR_TYPE
- O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações sobre os tipos de aceleradores compatíveis com cada versão de TPU, consulte Versões de TPU.
ZONE
- A zona em que você planeja criar a Cloud TPU.
RUNTIME_VERSION
- A versão do ambiente de execução do Cloud TPU.
TPU_NAME
- O nome atribuído pelo usuário ao Cloud TPU.
Crie uma fração do pod de TPU usando o comando
gcloud
. Por exemplo, para criar uma fração de pod v5p-32, use o seguinte comando:$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Instalar o JAX na fração do pod
Depois de criar a fração do pod de TPU, é necessário instalar o JAX em todos os hosts
nessa fração. Para isso, use o comando gcloud compute tpus tpu-vm ssh
com
os parâmetros --worker=all
e --commamnd
.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Executar o código JAX na fração do pod
Para executar o código JAX em uma fatia do pod de TPU, é preciso executar o código em cada host na
fração do pod de TPU. A chamada jax.device_count()
para de responder até ser
chamada em cada host no slice do pod. O exemplo a seguir ilustra como
executar um cálculo JAX em uma fração do pod de TPU.
Preparar o código
Você precisa da versão gcloud
>= 344.0.0 (para o
comando scp
).
Use gcloud --version
para verificar a versão de gcloud
e
execute gcloud components upgrade
, se necessário.
Crie um arquivo chamado example.py
com o seguinte código:
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copie example.py
para todas as VMs de worker de TPU na fração do pod.
$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \ --worker=all \ --zone=${ZONE} \ --project=${PROJECT_ID}
Se você não tiver usado o comando scp
antes, talvez receba um
erro semelhante a este:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Para resolver o erro, execute o comando ssh-add
conforme exibido na
mensagem de erro e execute o comando novamente.
Executar o código na fração do pod
Inicie o programa example.py
em todas as VMs:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --worker=all \ --command="python3 ./example.py"
Saída (produzida com uma fração do pod v4-32):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
Limpar
Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.
Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:
(vm)$ exit
Agora, o prompt precisa ser
username@projectname
, mostrando que você está no Cloud Shell.Exclua os recursos do Cloud TPU e do Compute Engine.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Execute
gcloud compute tpus execution-groups list
para verificar se os recursos foram excluídos. A exclusão pode levar vários minutos. A saída do comando a seguir não pode incluir nenhum dos recursos criados neste tutorial:$ gcloud compute tpus tpu-vm list --zone=${ZONE} \ --project=${PROJECT_ID}