Organiza tus páginas con colecciones Guarda y categoriza el contenido según tus preferencias.
Ejecuta un cálculo en una VM de Cloud TPU con JAX

Ejecutar un cálculo en una VM de Cloud TPU con JAX

En este documento, se proporciona una breve introducción al trabajo con JAX y Cloud TPU.

Antes de seguir esta guía de inicio rápido, debes crear una cuenta de Google Cloud Platform, instalar Google Cloud CLI y configurar el comando gcloud. Para obtener más información, consulta Configura una cuenta y un proyecto de Cloud TPU.

Instala Google Cloud CLI

Google Cloud CLI contiene herramientas y bibliotecas para interactuar con los productos y servicios de Google Cloud. Para obtener más información, consulta Instala Google Cloud CLI.

Configura el comando gcloud

Ejecuta los siguientes comandos a fin de configurar gcloud a fin de usar tu proyecto de Google Cloud y, luego, instalar los componentes necesarios para la vista previa de la VM de TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Habilita la API de Cloud TPU

  1. Habilita la API de Cloud TPU con el siguiente comando de gcloud en Cloud Shell. (También puedes habilitarla desde Google Cloud Console).

    $ gcloud services enable tpu.googleapis.com
    
  2. Ejecuta el siguiente comando para crear una identidad de servicio.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Crea una VM de Cloud TPU con gcloud

Con las VM de Cloud TPU, el modelo y el código se ejecutan directamente en la máquina host de TPU. Establece una conexión SSH directamente con el host de TPU. Puedes ejecutar código arbitrario, instalar paquetes, ver registros y depurar código directamente en el host de TPU.

  1. Para crear tu VM de TPU, ejecuta el siguiente comando desde Google Cloud Shell o la terminal de tu computadora en la que esté instalada la CLI de Google Cloud.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone us-central2-b \
    --accelerator-type v4-8 \
    --version tpu-vm-v4-base
    

    Campos obligatorios

    zone
    La zona en la que planeas crear tu Cloud TPU.
    accelerator-type
    El tipo de Cloud TPU que se creará.
    version
    La versión de software de Cloud TPU. Para las TPU v2 y v3, usa tpu-vm-base. Para las TPU v4, usa tpu-vm-v4-base.

Conéctate a tu VM de Cloud TPU

Establece una conexión SSH a tu VM de TPU con el siguiente comando:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone us-central2-b

Campos obligatorios

tpu_name
El nombre de la VM de TPU a la que te conectas.
zone
La zona en la que creaste tu Cloud TPU.

Instala JAX en tu VM de Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Verificación del sistema

Verifica que JAX pueda acceder a la TPU y ejecutar operaciones básicas:

Inicia el intérprete de Python 3:

(vm)$ python3
>>> import jax

Muestra la cantidad de núcleos de TPU disponibles:

>>> jax.device_count()

Se muestra la cantidad de núcleos de TPU. Si usas una TPU v4, debería ser 4. Si usas una TPU v2 o v3, debería ser 8.

Realiza un cálculo simple:

>>> jax.numpy.add(1, 1)

Se muestra el resultado de numpy add:

Resultado del comando:

Array(2, dtype=int32, weak_type=true)

Sal del intérprete de Python:

>>> exit()

Ejecuta el código JAX en una VM de TPU

Ahora puedes ejecutar cualquier código JAX que desees. Los ejemplos de lint son un excelente punto de partida para ejecutar modelos de AA estándar en JAX. Por ejemplo, para entrenar una red convolucional básica de MNIST, haz lo siguiente:

  1. Instala dependencias de ejemplos de Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. Instala FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Ejecuta la secuencia de comandos de entrenamiento de FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

La secuencia de comandos descarga el conjunto de datos y comienza el entrenamiento. El resultado de la secuencia de comandos debería verse así:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Realiza una limpieza

Cuando termines de usar tu VM de TPU, sigue estos pasos para limpiar tus recursos.

  1. Desconéctate de la instancia de Compute Engine, si aún no lo hiciste:

    (vm)$ exit
    
  2. Borra tu Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone us-central2-b
    
  3. Verifica que se hayan borrado los recursos mediante la ejecución del siguiente comando. Asegúrate de que tu TPU ya no aparezca en la lista. La eliminación puede tardar varios minutos.

    $ gcloud compute tpus tpu-vm list \
      --zone us-central2-b
    

Notas de rendimiento

Estos son algunos detalles importantes que son particularmente relevantes para usar TPU en JAX.

Relleno

Una de las causas más comunes del rendimiento lento en las TPU es introducir padding involuntario:

  • Los arreglos en Cloud TPU están en mosaicos. Esto implica rellenar una de las dimensiones a un múltiplo de 8 y una dimensión diferente a un múltiplo de 128.
  • La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de relleno.

bfloat16 dtype

De forma predeterminada, la multiplicación de matrices en JAX en las TPU usa bfloat16 con acumulación de float32. Esto se puede controlar con el argumento de precisión en llamadas relevantes a la función jax.numpy (matmul, punto, einsum, etcétera). En particular, considera lo siguiente:

  • precision=jax.lax.Precision.DEFAULT: usa una precisión mixta de bfloat16 (más rápida).
  • precision=jax.lax.Precision.HIGH: Usa varios pases de MXU para lograr una mayor precisión.
  • precision=jax.lax.Precision.HIGHEST: Usa aún más pases de MXU para lograr una precisión completa de float32.

JAX también agrega el dtype bfloat16, que puedes usar para convertir de forma explícita arreglos en bfloat16, p.ej., jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Ejecuta JAX en una Colab

Cuando ejecutas código JAX en un notebook de Colab, Colab crea automáticamente un nodo de TPU heredado. Los nodos de TPU tienen una arquitectura diferente. Para obtener más información, consulta Arquitectura del sistema.

¿Qué sigue?

Para obtener más información sobre Cloud TPU, consulta los siguientes vínculos: