Guía de inicio rápido sobre JAX y las VM de Cloud TPU

En este documento, se proporciona una breve introducción al trabajo con AJAX y Cloud TPU.

Antes de seguir esta guía de inicio rápido, debes crear una cuenta de Google Cloud Platform y, luego, instalar el SDK de Cloud. y configura el comando gcloud. Para obtener más información, consulta Configura una cuenta y un proyecto de Cloud TPU.

Instala el SDK de Cloud

El SDK de Cloud contiene herramientas y bibliotecas para interactuar con los productos y servicios de Google Cloud. Para obtener más información, consulta Instala el SDK de Cloud.

Configura el comando gcloud

Ejecuta los siguientes comandos a fin de configurar gcloud con el fin de usar tu proyecto de GCP e instalar los componentes necesarios para la vista previa de la VM de TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Habilita la API de Cloud TPU

  1. Habilita la API de Cloud TPU con el siguiente comando de gcloud en Cloud Shell. (También puedes habilitarlo desde Google Cloud Console).

    $ gcloud services enable tpu.googleapis.com
    
  2. Ejecuta el siguiente comando para crear una identidad de servicio.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crea una VM de Cloud TPU con gcloud

Con las VM de Cloud TPU, el modelo y el código se ejecutan directamente en la máquina anfitrión de TPU. Establece una conexión SSH directamente en el host de TPU. Puedes ejecutar código arbitrario, instalar paquetes, ver registros y depurar código directamente en el host de TPU.

  1. Para crear tu VM de TPU, ejecuta el siguiente comando desde un Cloud Shell de GCP o tu terminal de computadora en la que esté instalado el SDK de Cloud.

    (vm)$  gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    Campos obligatorios

    zone
    Es la zona en la que deseas crear la Cloud TPU.
    accelerator-type
    El tipo de Cloud TPU que se creará.
    version
    La versión del entorno de ejecución de Cloud TPU. Configúralo como “v2-alpha” cuando uses JAX en dispositivos de TPU únicos, secciones de pod o pods completos.

Conéctate a tu VM de Cloud TPU

Establece una conexión SSH a tu VM de TPU con el siguiente comando:

$  gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Campos obligatorios

tpu_name
El nombre de la VM de TPU a la que te estás conectando.
zone
La zona en la que creaste tu Cloud TPU.

Instala JAX en tu VM de Cloud TPU

(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificación del sistema

Para verificar que todo esté instalado correctamente, verifica que JAX vea los núcleos de Cloud TPU y pueda ejecutar operaciones básicas:

Inicia el intérprete de Python 3:

(vm)$ python3
>>> import jax

Muestra la cantidad de núcleos de TPU disponibles:

>>> jax.device_count()

Se muestra la cantidad de núcleos de TPU, que debe ser 8

Realiza un cálculo simple:

>>> jax.numpy.add(1, 1)

Se muestra el resultado de la adición de NumPy:

Resultado del comando:

DeviceArray(2, dtype=int32)

Sal del intérprete de Python:

>>> exit()

Ejecuta el código JAX en una VM de TPU

Ahora puedes ejecutar el código JAX que quieras. Los ejemplos de flax son un excelente punto de partida para ejecutar modelos de AA estándar en JAX. Por ejemplo, para entrenar una red convolucional básica de MNIST, sigue estos pasos:

  1. Instala conjuntos de datos de TensorFlow

    (vm)$ pip install --upgrade clu
    
  2. Instala FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Ejecuta la secuencia de comandos de entrenamiento de FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    El resultado de la secuencia de comandos debería ser similar al siguiente:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Realice una limpieza

Cuando termines con la VM de TPU, sigue estos pasos para limpiar los recursos.

  1. Desconéctate de la instancia de Compute Engine, si aún no lo hiciste:

    (vm)$ exit
    
  2. Borra tu Cloud TPU.

    $  gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Ejecuta el siguiente comando para verificar que los recursos se hayan borrado. Asegúrate de que la TPU ya no esté en la lista. La eliminación puede tardar varios minutos.

Notas de rendimiento

A continuación, se indican algunos detalles importantes que son particularmente relevantes para el uso de TPU en AJAX.

Relleno

Una de las causas más comunes del rendimiento lento en las TPU es el relleno involuntario:

  • Los arreglos en Cloud TPU están en mosaicos. Esto implica el relleno de una de las dimensiones a un múltiplo de 8 y de otra a un múltiplo de 128.
  • La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de relleno.

bfloat16 dtype

De forma predeterminada, la multiplicación de matrices en JAX en TPU usa bfloat16 con acumulación de float32. Esto se puede controlar con el argumento de precisión en llamadas a funciones jax.numpy relevantes (matmul, punto, einsum, etcétera). En particular, considera lo siguiente:

  • precision=jax.lax.Precision.DEFAULT: usa una precisión de bfloat16 mixta (más rápida)
  • precision=jax.lax.Precision.HIGH: Usa varios pases MXU para lograr una mayor precisión
  • precision=jax.lax.Precision.HIGHEST: usa aún más pases MXU para lograr una precisión total de float32

JAX también agrega el dtype bfloat16, que puedes usar para convertir arreglos de forma explícita en bfloat16. P. ej., jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Ejecuta JAX en una Colab

Cuando ejecutas código JAX en un notebook de Colab, Colab crea un nodo TPU heredado de forma automática. Los nodos TPU tienen una arquitectura diferente. Para obtener más información, consulta Arquitectura del sistema.