Executar o código JAX em frações do pod de TPU
Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração do pod de TPU. As frações do pod de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações do pod de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.
Se você quiser usar o NFS montado para armazenamento de dados, defina o Login do SO para todas as VMs de TPU na fração do pod. Para mais informações, consulte Como usar um NFS para armazenar dados.Configurar o ambiente
No Cloud Shell, execute o seguinte comando para garantir que você está executando a versão atual do
gcloud
:$ gcloud components update
Se você precisar instalar
gcloud
, use o seguinte comando:$ sudo apt install -y google-cloud-sdk
Crie algumas variáveis de ambiente:
$ export TPU_NAME=tpu-name $ export ZONE=us-central2-b $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-32
Criar uma fração do pod de TPU
Antes de executar os comandos neste documento, verifique se você seguiu as instruções em Configurar uma conta e um projeto do Cloud TPU. Execute os comandos a seguir na máquina local.
Crie uma fração do pod de TPU usando o comando gcloud
. Por exemplo, para criar uma
fração de pod v4-32, use o seguinte comando:
$ gcloud compute tpus tpu-vm create ${TPU_NAME} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --version=${RUNTIME_VERSION}
Instalar o JAX na fração do pod
Depois de criar a fração do pod de TPU, é necessário instalar o JAX em todos os hosts
nessa fração. É possível instalar o JAX em todos os hosts com um único comando usando a
opção --worker=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --worker=all \ --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Executar o código JAX na fração do pod
Para executar o código JAX em uma fatia do pod de TPU, é preciso executar o código em cada host na
fração do pod de TPU. A chamada jax.device_count()
para de responder até ser
chamada em cada host no slice do pod. O exemplo a seguir ilustra como
executar um cálculo de JAX em uma fração do Pod de TPU.
Preparar o código
Você precisa da versão gcloud
>= 344.0.0 (para o
comando scp
).
Use gcloud --version
para verificar a versão do gcloud
.
execute gcloud components upgrade
, se necessário.
Crie um arquivo chamado example.py
com o seguinte código:
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copie example.py
para todas as VMs de worker de TPU na fração do pod.
$ gcloud compute tpus tpu-vm scp example.py ${TPU_NAME} \ --worker=all \ --zone=${ZONE}
Se você não usou o comando scp
antes, talvez veja uma
semelhante a este:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.
Para resolver o erro, execute o comando ssh-add
conforme exibido na
e execute o comando novamente.
Executar o código na fração do pod
Inicie o programa example.py
em todas as VMs:
$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --zone=${ZONE} \ --worker=all \ --command="python3 example.py"
Saída (produzida com uma fração do pod v4-32):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
Limpar
Quando terminar de usar a VM de TPU, siga estas etapas para limpar os recursos.
Desconecte-se da instância do Compute Engine, caso ainda não tenha feito isso:
(vm)$ exit
Agora, o prompt precisa ser
username@projectname
, mostrando que você está no Cloud Shell.Exclua os recursos do Cloud TPU e do Compute Engine.
$ gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --zone=${ZONE}
Execute
gcloud compute tpus execution-groups list
para verificar se os recursos foram excluídos. A exclusão pode levar vários minutos. A saída do comando a seguir não pode incluir nenhum dos recursos criados neste tutorial:$ gcloud compute tpus tpu-vm list --zone=${ZONE}