Ejecutar un cálculo en una VM de TPU de Cloud con JAX

En este documento se ofrece una breve introducción sobre cómo trabajar con JAX y las TPUs de Cloud.

Antes de empezar

Antes de ejecutar los comandos de este documento, debes crear una cuenta, instalar la CLI de Google Cloud y configurar el comando gcloud. Google CloudPara obtener más información, consulta Configurar el entorno de TPU de Cloud.

Crea una VM de TPU de Cloud con gcloud.

  1. Define algunas variables de entorno para que los comandos sean más fáciles de usar.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno.
    TPU_NAME El nombre de la TPU.
    ZONE La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION La versión de software de la TPU de Cloud.

  2. Crea tu VM de TPU ejecutando el siguiente comando desde Cloud Shell o desde el terminal de tu ordenador, donde esté instalada la CLI de Google Cloud.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Conéctate a tu VM de TPU de Cloud

Conéctate a tu VM de TPU a través de SSH con el siguiente comando:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

Si no puedes conectarte a una VM de TPU mediante SSH, puede deberse a que la VM de TPU no tiene una dirección IP externa. Para acceder a una VM de TPU sin una dirección IP externa, sigue las instrucciones que se indican en Conectarse a una VM de TPU sin una dirección IP pública.

Instalar JAX en tu máquina virtual de TPU de Cloud

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Prueba del sistema

Verifica que JAX puede acceder a la TPU y ejecutar operaciones básicas:

  1. Inicia el intérprete de Python 3:

    (vm)$ python3
    >>> import jax
  2. Muestra el número de núcleos de TPU disponibles:

    >>> jax.device_count()

Se muestra el número de núcleos de TPU. El número de núcleos que se muestra depende de la versión de TPU que estés usando. Para obtener más información, consulta Versiones de TPU.

Hacer un cálculo

>>> jax.numpy.add(1, 1)

Se muestra el resultado de la suma de NumPy:

Resultado del comando:

Array(2, dtype=int32, weak_type=True)

Salir del intérprete de Python

>>> exit()

Ejecutar código JAX en una VM de TPU

Ahora puedes ejecutar el código JAX que quieras. Los ejemplos de Flax son un buen punto de partida para ejecutar modelos de aprendizaje automático estándar en JAX. Por ejemplo, para entrenar una red convolucional MNIST básica, haz lo siguiente:

  1. Instala las dependencias de los ejemplos de Flax:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Instalar Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Ejecuta la secuencia de comandos de entrenamiento de MNIST de Flax:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

La secuencia de comandos descarga el conjunto de datos e inicia el entrenamiento. La salida de la secuencia de comandos debería tener este aspecto:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Limpieza

Para evitar que se apliquen cargos en tu cuenta de Google Cloud por los recursos utilizados en esta página, sigue estos pasos.

Cuando hayas terminado de usar tu VM de TPU, sigue estos pasos para limpiar tus recursos.

  1. Desconéctate de la instancia de TPU de Cloud (si aún no lo has hecho):

    (vm)$ exit

    A continuación, se mostrará el mensaje nombredeusuario@nombreproyecto, que indica que estás en Cloud Shell.

  2. Elimina tu TPU de Cloud:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. Para comprobar que los recursos se han eliminado, ejecuta el siguiente comando. Comprueba que tu TPU ya no aparezca en la lista. El proceso de eliminación puede tardar varios minutos.

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

Notas sobre el rendimiento

A continuación, se indican algunos detalles importantes que son especialmente relevantes para usar las TPUs en JAX.

Relleno

Una de las causas más habituales de que las TPUs tengan un rendimiento lento es la introducción de relleno por error:

  • Las matrices de la TPU de Cloud se organizan en mosaicos. Para ello, se debe añadir relleno a una de las dimensiones hasta que sea múltiplo de 8 y a la otra dimensión hasta que sea múltiplo de 128.
  • La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de relleno.

Tipo de datos bfloat16

De forma predeterminada, la multiplicación de matrices en JAX en TPUs usa bfloat16 con acumulación de float32. Esto se puede controlar con el argumento de precisión en las llamadas de función jax.numpy pertinentes (matmul, dot, einsum, etc.). En particular:

  • precision=jax.lax.Precision.DEFAULT: usa precisión mixta bfloat16 (la más rápida)
  • precision=jax.lax.Precision.HIGH: usa varias pasadas de MXU para conseguir una mayor precisión.
  • precision=jax.lax.Precision.HIGHEST: usa aún más pases de MXU para conseguir una precisión float32 completa.

JAX también añade el tipo de datos bfloat16, que puedes usar para convertir explícitamente las matrices a bfloat16. Por ejemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Siguientes pasos

Para obtener más información sobre las TPU de Cloud, consulta los siguientes recursos: