Executar o código JAX em frações do pod de TPU
Depois de executar o código JAX em uma única placa de TPU, é possível escalonar verticalmente o código executando-o em uma fração do pod de TPU. As frações do pod de TPU são várias placas de TPU conectadas entre si por conexões de rede dedicadas de alta velocidade. Este documento é uma introdução à execução do código JAX nas frações do pod de TPU. Para informações mais detalhadas, consulte Como usar o JAX em ambientes com vários hosts e processos.
Se você quiser usar o NFS ativado para armazenamento de dados, defina o Login do SO para todas as VMs de TPU na fração do pod. Para mais informações, consulte Como usar um NFS para armazenamento de dados.Criar uma fração do pod de TPU
Antes de executar os comandos neste documento, siga as instruções em Configurar uma conta e um projeto do Cloud TPU. Execute os comandos a seguir na máquina local.
Crie uma fração do pod de TPU usando o comando gcloud
. Por exemplo, para criar uma fração de pod v4-32, use o seguinte comando:
$ gcloud compute tpus tpu-vm create tpu-name \
--zone=us-central2-b \
--accelerator-type=v4-32 \
--version=tpu-ubuntu2204-base
Instalar o JAX na fração do pod
Depois de criar a fração do pod de TPU, é necessário instalar o JAX em todos os hosts
nessa fração. É possível instalar o JAX em todos os hosts com um único comando usando a
opção --worker=all
:
gcloud compute tpus tpu-vm ssh tpu-name \ --zone=us-central2-b --worker=all --command="pip install \ --upgrade 'jax[tpu]>0.3.0' \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
Executar o código JAX na fração do pod
Para executar o código JAX em uma fatia do pod de TPU, é preciso executar o código em cada host na
fração do pod de TPU. A chamada jax.device_count()
para de responder até ser chamada em cada host na fração do pod. O exemplo a seguir ilustra como executar um cálculo simples do JAX em uma fração do Pod de TPU.
Preparar o código
Você precisa da versão 344.0.0 ou mais recente do gcloud
(para o
comando scp
).
Use gcloud --version
para verificar a versão do gcloud
e, se necessário, execute gcloud components upgrade
.
Crie um arquivo chamado example.py
com o seguinte código:
# The following code snippet will be run on all TPU hosts
import jax
# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()
# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)
# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
Copiar example.py
para todas as VMs de worker da TPU na fração do pod
$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
--worker=all \
--zone=us-central2-b
Se você ainda não tiver usado o comando scp
, poderá ver um erro semelhante ao seguinte:
ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try again.
Para resolver o erro, execute o comando ssh-add
conforme exibido na
mensagem de erro e execute o comando novamente.
Executar o código na fração do pod
Inicie o programa example.py
em todas as VMs:
$ gcloud compute tpus tpu-vm ssh tpu-name \
--zone=us-central2-b \
--worker=all \
--command="python3 example.py"
Saída (produzida com uma fração do pod v4-32):
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
Limpar
Quando terminar, você poderá liberar os recursos de VM da TPU usando o comando
gcloud
:
$ gcloud compute tpus tpu-vm delete tpu-name \
--zone=us-central2-b