Ejecuta un cálculo en una VM de Cloud TPU con JAX
En este documento, se proporciona una breve introducción al trabajo con JAX y Cloud TPU.
Antes de comenzar
Antes de ejecutar los comandos de este documento, debes crear una cuenta de Google Cloud, instalar Google Cloud CLI y configurar el comando gcloud
. Para obtener más información, consulta Configura el entorno de Cloud TPU.
Crea una VM de Cloud TPU con gcloud
Define algunas variables de entorno para facilitar el uso de los comandos.
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
Descripciones de las variables de entorno
PROJECT_ID
- El Google Cloud ID de tu proyecto.
ACCELERATOR_TYPE
- El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles para cada versión de TPU, consulta Versiones de TPU.
ZONE
- Es la zona en la que deseas crear la Cloud TPU.
RUNTIME_VERSION
- La versión del entorno de ejecución de Cloud TPU. Para obtener más información, consulta Imágenes de VM de TPU.
TPU_NAME
- El nombre asignado por el usuario a tu Cloud TPU.
Para crear tu VM de TPU, ejecuta el siguiente comando desde Cloud Shell o la terminal de tu computadora en la que esté instalada la CLI de Google Cloud.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Conéctate a tu VM de Cloud TPU
Usa el siguiente comando para conectarte a tu VM de TPU a través de SSH:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Instala JAX en tu VM de Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Verificación del sistema
Verifica que JAX pueda acceder a la TPU y ejecutar operaciones básicas:
Inicia el intérprete de Python 3:
(vm)$ python3
>>> import jax
Muestra la cantidad de núcleos de TPU disponibles:
>>> jax.device_count()
Se muestra la cantidad de núcleos de TPU. La cantidad de núcleos que se muestran depende de la versión de TPU que usas. Para obtener más información, consulta Versiones de TPU.
Realiza un cálculo:
>>> jax.numpy.add(1, 1)
Se muestra el resultado de la adición de numpy:
Resultado del comando:
Array(2, dtype=int32, weak_type=true)
Sal del intérprete de Python:
>>> exit()
Ejecuta el código JAX en una VM de TPU
Ahora puedes ejecutar cualquier código JAX que desees. Los ejemplos de flax son un excelente punto de partida para ejecutar modelos de AA estándar en JAX. Por ejemplo, para entrenar una red convolucional básica de MNIST, haz lo siguiente:
Instala las dependencias de los ejemplos de Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Instala FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Ejecuta la secuencia de comandos de entrenamiento de MNIST de FLAX
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
La secuencia de comandos descarga el conjunto de datos y comienza el entrenamiento. El resultado de la secuencia de comandos debería verse de la siguiente manera:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Limpia
Para evitar que se apliquen cargos a tu Google Cloud cuenta por los recursos que usaste en esta página, sigue estos pasos.
Cuando termines de usar la VM de TPU, sigue estos pasos para limpiar los recursos.
Desconéctate de la instancia de Compute Engine, si aún no lo hiciste:
(vm)$ exit
Borra tu Cloud TPU.
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Ejecuta el siguiente comando para verificar que los recursos se hayan borrado. Asegúrate de que tu TPU ya no aparezca en la lista. La eliminación puede tardar varios minutos.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
Notas de rendimiento
Estos son algunos detalles importantes que son particularmente relevantes para usar las TPU en JAX.
Relleno
Una de las causas más comunes del rendimiento lento en las TPU es la introducción de relleno imprevisto:
- Los arreglos en Cloud TPU están en mosaicos. Esto implica el relleno de una de las dimensiones a un múltiplo de 8 y de otra a un múltiplo de 128.
- La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de relleno.
bfloat16 dtype
De forma predeterminada, la multiplicación de matrices en JAX en TPU usa bfloat16 con acumulación de float32. Esto se puede controlar con el argumento de precisión en las llamadas a funciones relevantes de jax.numpy (matmul, dot, einsum, etc.). En particular:
precision=jax.lax.Precision.DEFAULT
: Usa precisión bfloat16 mixta (más rápida).precision=jax.lax.Precision.HIGH
: Usa varios pases de MXU para lograr una mayor precisión.precision=jax.lax.Precision.HIGHEST
: Usa aún más pases de MXU para lograr una precisión completa de float32.
JAX también agrega el tipo de datos bfloat16, que puedes usar para transmitir arrays de forma explícita a bfloat16
, por ejemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
¿Qué sigue?
Para obtener más información sobre Cloud TPU, consulta los siguientes vínculos:
- Ejecuta el código JAX en porciones de pod de TPU
- Administra TPU
- Arquitectura de sistemas de Cloud TPU