Ejecuta un cálculo en una VM de Cloud TPU con JAX
En este documento, se proporciona una breve introducción al trabajo con JAX y Cloud TPU.
Antes de seguir esta guía de inicio rápido, debes crear una cuenta de Google Cloud Platform, instalar Google Cloud CLI y configurar el comando gcloud
.
Para obtener más información, consulta Configura una cuenta y un proyecto de Cloud TPU.
Instala Google Cloud CLI
Google Cloud CLI contiene herramientas y bibliotecas para interactuar con los productos y servicios de Google Cloud. Para obtener más información, consulta Instala Google Cloud CLI.
Configura el comando gcloud
Ejecuta los siguientes comandos a fin de configurar gcloud
para usar tu proyecto de Google Cloud y, luego, instalar los componentes necesarios para la vista previa de la VM de TPU.
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
Habilita la API de Cloud TPU
Habilita la API de Cloud TPU con el siguiente comando
gcloud
en Cloud Shell. También puedes habilitarla en la consola de Google Cloud.$ gcloud services enable tpu.googleapis.com
Ejecuta el siguiente comando para crear una identidad de servicio.
$ gcloud beta services identity create --service tpu.googleapis.com
Crea una VM de Cloud TPU con gcloud
Con las VMs de Cloud TPU, tu modelo y código se ejecutan directamente en la máquina anfitrión de TPU. Establece una conexión SSH directamente en el host de TPU. Puedes ejecutar código arbitrario, instalar paquetes, ver registros y depurar código directamente en el host de TPU.
Para crear tu VM de TPU, ejecuta el siguiente comando desde Cloud Shell o la terminal de tu computadora en la que está instalado Google Cloud CLI.
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central2-b \ --accelerator-type=v4-8 \ --version=tpu-ubuntu2204-base
Campos obligatorios
zone
- La zona en la que planeas crear tu Cloud TPU.
accelerator-type
- El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que quieres crear. Si quieres obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
version
- La versión de software de Cloud TPU. Para todos los tipos de TPU, usa
tpu-ubuntu2204-base
.
Conéctate a tu VM de Cloud TPU
Establece una conexión SSH a la VM de TPU con el siguiente comando:
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b
Campos obligatorios
tpu_name
- El nombre de la VM de TPU a la que te conectas.
zone
- La zona en la que creaste tu Cloud TPU.
Instala JAX en tu VM de Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Verificación del sistema
Verifica que JAX pueda acceder a la TPU y ejecutar operaciones básicas:
Inicia el intérprete de Python 3:
(vm)$ python3
>>> import jax
Muestra la cantidad de núcleos de TPU disponibles:
>>> jax.device_count()
Se muestra la cantidad de núcleos de TPU. Si usas una TPU v4, debería ser 4
. Si usas una TPU v2 o v3, debería ser 8
.
Realiza un cálculo simple:
>>> jax.numpy.add(1, 1)
Se muestra el resultado de la adición de NumPy:
Resultado del comando:
Array(2, dtype=int32, weak_type=true)
Sal del intérprete de Python:
>>> exit()
Ejecuta el código JAX en una VM de TPU
Ahora puedes ejecutar cualquier código JAX que desees. Los ejemplos de flax son un excelente lugar para comenzar a ejecutar modelos de AA estándar en JAX. Por ejemplo, para entrenar una red convolucional de MNIST básica, haz lo siguiente:
Instala las dependencias de los ejemplos de Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Instalar FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Ejecuta la secuencia de comandos de entrenamiento de FLAX MNIST
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
La secuencia de comandos descargará el conjunto de datos y comenzará el entrenamiento. El resultado de la secuencia de comandos debería verse de la siguiente manera:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Limpia
Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.
Desconéctate de la instancia de Compute Engine, si aún no lo hiciste:
(vm)$ exit
Borra tu Cloud TPU.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central2-b
Ejecuta el siguiente comando para verificar que se hayan borrado los recursos. Asegúrate de que tu TPU ya no aparezca en la lista. La eliminación puede tardar varios minutos.
$ gcloud compute tpus tpu-vm list \ --zone=us-central2-b
Notas de rendimiento
Estos son algunos detalles importantes que son particularmente relevantes para el uso de TPU en JAX.
Relleno
Una de las causas más comunes del rendimiento lento en las TPU es la introducción de padding involuntario:
- Los arreglos en Cloud TPU están en mosaicos. Esto implica rellenar una de las dimensiones a un múltiplo de 8 y otra diferente a un múltiplo de 128.
- La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de rellenar.
bfloat16 dtype
De forma predeterminada, la multiplicación de matrices en JAX en TPU usa bfloat16 con acumulación float32. Esto se puede controlar con el argumento de precisión en llamadas a la función jax.numpy relevantes (matmul, punto, einsum, etcétera). En particular, considera lo siguiente:
precision=jax.lax.Precision.DEFAULT
: Usa una precisión de bfloat16 mixta (más rápida)precision=jax.lax.Precision.HIGH
: Usa varios pases MXU para lograr una mayor precisiónprecision=jax.lax.Precision.HIGHEST
: Usa incluso más pases MXU para lograr la precisión completa de float32
JAX también agrega el dtype bfloat16, que puedes usar para convertir arrays de manera explícita en bfloat16
, por ejemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Ejecuta JAX en una Colab
Cuando ejecutas código JAX en un notebook de Colab, Colab crea automáticamente un nodo TPU heredado. Los nodos TPU tienen una arquitectura diferente. Para obtener más información, consulta Arquitectura del sistema.
¿Qué sigue?
Para obtener más información sobre Cloud TPU, consulta los siguientes vínculos:
- Ejecuta código JAX en porciones de pod de TPU
- Administra las TPU
- Arquitectura de sistemas de Cloud TPU