Ejecuta el código JAX en porciones de pod de TPU

Una vez que tu código JAX se ejecute en un único panel de TPU, puedes escalarlo de forma vertical en una porción de pod de TPU. Las porciones de pod de TPU son varios paneles de TPU conectados entre sí en conexiones de red dedicadas de alta velocidad. Este documento es una introducción a la ejecución de código JAX en porciones de pod de TPU. Para obtener información más detallada, consulta Usa JAX en entornos de hosts múltiples y de procesos múltiples.

Si deseas usar NFS activado para el almacenamiento de datos, debes configurar el Acceso al SO para todas las VM de TPU en la porción del Pod. Si quieres obtener más información, consulta Usa un NFS para el almacenamiento de datos.

Crea una porción de pod de TPU

Antes de ejecutar los comandos de este documento, asegúrate de haber seguido las instrucciones en Configura una cuenta y un proyecto de Cloud TPU. Ejecuta los siguientes comandos en tu máquina local.

Crea una porción de pod de TPU con el comando gcloud. Por ejemplo, para crear una porción de Pod v4-32, usa el siguiente comando:

$ gcloud compute tpus tpu-vm create tpu-name  \
  --zone=us-central2-b \
  --accelerator-type=v4-32  \
  --version=tpu-ubuntu2204-base 

Instala JAX en la porción de pod

Después de crear la porción de pod de TPU, debes instalar JAX en todos los hosts en la porción de pod de TPU. Puedes instalar JAX en todos los hosts con un solo comando mediante la opción --worker=all:

  gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b --worker=all --command="pip install \
  --upgrade 'jax[tpu]>0.3.0' \
  -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Ejecuta el código JAX en la porción de pod

Para ejecutar el código JAX en una porción de pod de TPU, debes ejecutar el código en cada host en la porción de pod de TPU. La llamada jax.device_count() deja de responder hasta que se la llama en cada host en la porción de Pod. En el siguiente ejemplo, se muestra cómo ejecutar un cálculo simple de JAX en una porción de pod de TPU.

Prepara el código

Necesitas la versión de gcloud >= 344.0.0 (para el comando scp). Usa gcloud --version para verificar tu versión de gcloud y ejecuta gcloud components upgrade si es necesario.

Crea un archivo llamado example.py con el siguiente código:

# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py en todas las VM de trabajador TPU en la porción de Pod

$ gcloud compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all \
  --zone=us-central2-b

Si no usaste el comando scp antes, es posible que veas un error similar al siguiente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver el error, ejecuta el comando ssh-add como se muestra en el mensaje de error y vuelve a ejecutarlo.

Ejecuta el código en la porción de pod

Inicia el programa example.py en cada VM:

$ gcloud compute tpus tpu-vm ssh tpu-name \
  --zone=us-central2-b \
  --worker=all \
  --command="python3 example.py"

Resultado (producido con una porción de Pod v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Limpia

Cuando termines, puedes liberar los recursos de VM de TPU con el comando gcloud:

$ gcloud compute tpus tpu-vm delete tpu-name \
  --zone=us-central2-b