Ejecuta código JAX en los fragmentos de pod de TPU

Después de ejecutar tu código JAX en una sola placa de TPU, puedes escalarlo mediante su ejecución en una porción de pod de TPU. Los fragmentos de pod de TPU son varios tableros de TPU conectados entre sí en conexiones de red dedicadas de alta velocidad. Este documento es una introducción a la ejecución del código JAX en los fragmentos de pod de TPU. Para obtener información más detallada, consulta Cómo usar JAX en entornos de varios procesos y procesos múltiples.

Crea una porción de pod de TPU

Crea una porción de pod de TPU con el comando gcloud. Por ejemplo, para crear una porción de pod v2-32, usa el siguiente comando:

$ gcloud alpha compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version v2-alpha

Instale JAX en la porción del pod

Después de crear la porción de pod de TPU, debes instalar JAX en todos los hosts en la porción de pod de TPU. Puedes instalar JAX en todos los hosts con un solo comando mediante la opción --worker=all:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install --upgrade jax jaxlib"

Ejecuta código JAX en la porción de pod

Para ejecutar el código JAX en una porción de pod de TPU, debes ejecutar el código en cada host en la porción de pod de TPU. Esto significa que debes establecer una conexión SSH con cada host y ejecutar el código JAX en cada host. En el código siguiente de Python, se ilustra cómo ejecutar un cálculo simple de JAX en una porción de pod de TPU con la opción --worker=all del comando gcloud.

Preparar código

$ read -r -d '' PYTHON_CMD << EOF
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)
EOF

Ejecuta el código de la porción de pod

$ gcloud alpha compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command "python3 -c \"$PYTHON_CMD\""

Resultado (producido con una porción de pod v2-32):

global device count: 32
local device count: 8
pmap result: [32. 32. 32. 32. 32. 32. 32. 32.]

Esta es una forma de ejecutar código Python JAX en cada host, pero puedes usar los métodos que desees. Sin embargo, la ejecución anterior de jax.device_count() se ejecutará hasta que se llame en cada host en la porción de pod, ya que todos los hosts deben estar presentes para inicializar el entorno de ejecución de la TPU.

Limpia

Cuando termines, puedes liberar tus recursos de VM de TPU con el comando gcloud:

$ gcloud alpha compute tpus tpu-vm delete tpu-name \
  --zone europe-west4-a