Ejecuta el código JAX en porciones de TPU

Antes de ejecutar los comandos de este documento, asegúrate de haber seguido las instrucciones que se indican en Configura una cuenta y un proyecto de Cloud TPU.

Una vez que tu código JAX se ejecute en una sola placa de TPU, puedes escalarlo de forma vertical en una porción de TPU. Las porciones de TPU son varias placas de TPU conectadas entre sí en conexiones de red dedicadas de alta velocidad. Este documento es una introducción a la ejecución de código JAX en porciones de TPU. Para obtener información más detallada, consulta Usa JAX en entornos de hosts múltiples y de procesos múltiples.

Si deseas usar NFS activado para el almacenamiento de datos, debes configurar el acceso al SO para todas las VM de TPU en la porción. Si deseas obtener más información, consulta Usa un NFS para el almacenamiento de datos.

Crea una porción de Cloud TPU

  1. Crea algunas variables de entorno:

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-32
    export ZONE=europe-west4-b
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Descripciones de las variables de entorno

    PROJECT_ID
    El ID Google Cloud de tu proyecto.
    ACCELERATOR_TYPE
    El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    ZONE
    Es la zona en la que deseas crear la Cloud TPU.
    RUNTIME_VERSION
    La versión del entorno de ejecución de Cloud TPU.
    TPU_NAME
    El nombre asignado por el usuario a tu Cloud TPU.
  2. Crea una porción de TPU con el comando gcloud. Por ejemplo, para crear una porción v5p-32, usa el siguiente comando:

    $ gcloud compute tpus tpu-vm create ${TPU_NAME}  \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --accelerator-type=${ACCELERATOR_TYPE}  \
    --version=${RUNTIME_VERSION} 

Instala JAX en tu porción

Después de crear la porción de TPU, debes instalar JAX en todos los hosts en la porción de TPU. Puedes hacerlo con el comando gcloud compute tpus tpu-vm ssh usando los parámetros --worker=all y --commamnd.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Ejecuta el código JAX en la porción

Para ejecutar el código JAX en una porción de TPU, debes ejecutar el código en cada host en la porción de TPU. La llamada a jax.device_count() deja de responder hasta que se llama a cada host en la porción. En el siguiente ejemplo, se ilustra cómo ejecutar un cálculo de JAX en una porción de TPU.

Prepara el código

Necesitas la versión gcloud >= 344.0.0 (para el comando scp). Usa gcloud --version para verificar tu versión de gcloud y ejecuta gcloud components upgrade, si es necesario.

Crea un archivo llamado example.py con el siguiente código:


import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copia example.py en todas las VMs de trabajadores de TPU de la porción

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
  --worker=all \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

Si no usaste el comando scp anteriormente, es posible que veas un error similar al siguiente:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

Para resolver el error, ejecuta el comando ssh-add como se muestra en el mensaje de error y vuelve a ejecutarlo.

Ejecuta el código en la porción

Inicia el programa example.py en cada VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --worker=all \
  --command="python3 ./example.py"

Resultado (producido con una porción v4-32):

global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]

Limpia

Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.

  1. Desconéctate de la instancia de Compute Engine, si aún no lo hiciste:

    (vm)$ exit

    El mensaje ahora debería mostrar username@projectname, que indica que estás en Cloud Shell.

  2. Borra tus recursos de Cloud TPU y Compute Engine.

    $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID}
  3. Ejecuta gcloud compute tpus execution-groups list para verificar que los recursos se hayan borrado. La eliminación puede tardar varios minutos. El resultado del siguiente comando no debe incluir ninguno de los recursos creados en este instructivo:

    $ gcloud compute tpus tpu-vm list --zone=${ZONE} \
    --project=${PROJECT_ID}