Ejecuta el código JAX en porciones de pod de TPU

Una vez que tu código JAX se ejecute en una sola placa TPU, puedes escalar verticalmente el código mediante su ejecución en una porción de pod de TPU. Las porciones de pod de TPU son varias placas de TPU conectadas entre sí en conexiones de red dedicadas de alta velocidad. Este documento es una introducción a la ejecución de código JAX en porciones de pod de TPU. Para obtener información más detallada, consulta Cómo usar JAX en entornos de varios procesos y hosts.

Si deseas usar NFS activado para el almacenamiento de datos, debes configurar el Acceso al SO de todas las VM de TPU en la porción de pod. A fin de obtener más información, consulta Usa un NFS para el almacenamiento de datos.

Crea una porción de pod de TPU

Antes de ejecutar los comandos de este documento, asegúrate de haber seguido las instrucciones en Configura una cuenta y un proyecto de Cloud TPU. Ejecuta los siguientes comandos en tu máquina local.

Crea una porción de pod de TPU con el comando gcloud. Por ejemplo, para crear una porción de pod v2-32 usa el siguiente comando:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

Instala JAX en la porción de pod

Después de crear la porción de pod de TPU, debes instalar JAX en todos los hosts de la porción de pod de TPU. Puedes instalar JAX en todos los hosts con un solo comando mediante la opción --worker=all:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

Ejecuta el código JAX en el fragmento de Pod

Para ejecutar el código JAX en un fragmento de pod de TPU, debes ejecutar el código en cada fragmento del pod de TPU. Esto significa que debes establecer una conexión SSH con cada host y ejecutar el código JAX en cada host. Con el siguiente código de Python, se muestra cómo ejecutar un cálculo simple de JAX en una porción de pod de TPU con la opción --worker=all del comando de gcloud.

Prepara el código

Necesitas la versión de gcloud >= 344.0.0 (para el comando scp). Usa gcloud --version para verificar la versión de gcloud y ejecuta gcloud components upgrade si es necesario.

Escribe example.py en la máquina local:

   cat > example.py << EOF

   # The following code snippet will be run on all TPU hosts
   import jax

   # The total number of TPU cores in the Pod
   device_count = jax.device_count()

   # The number of TPU cores attached to this host
   local_device_count = jax.local_device_count()

   # The psum is performed over all mapped devices across the Pod
   xs = jax.numpy.ones(jax.local_device_count())
   r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

   # Print from a single host to avoid duplicated output
   if jax.process_index() == 0:
       print('global device count:', jax.device_count())
       print('local device count:', jax.local_device_count())
       print('pmap result:', r)
   EOF

Copia example.py en todas las VM en la porción de pod.

$ gcloud alpha compute tpus tpu-vm scp example.py tpu-name: \
  --worker=all --zone=europe-west4-a

Si es la primera vez que usas el comando scp, es posible que veas un error similar al siguiente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH agent.
Please run "ssh-add /.../.ssh/google_compute_engine" to add it, and try again.
Ejecuta el comando ssh-add como se muestra en el mensaje de error y vuelve a ejecutarlo para resolver el error.

Ejecuta el código en la porción de pod

Inicia el programa example.py en cada VM:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a --worker=all --command "python3 example.py"

Resultado (producido con una porción de pod v2-32):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

Esta es una manera de ejecutar el código Python de JAX en cada host, pero puedes usar los métodos que desees. Sin importar cómo lo ejecutes, la llamada jax.device_count() anterior quedará pendiente hasta que se llame en cada host del fragmento del pod, porque todos los hosts deben estar presentes para inicializar el entorno de ejecución de TPU.

Limpia

Cuando termines, puedes liberar los recursos de la VM de TPU mediante el comando gcloud:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a