Organiza tus páginas con colecciones
Guarda y categoriza el contenido según tus preferencias.
Ejecuta un cálculo en una VM de Cloud TPU con JAX
En este documento, se proporciona una breve introducción al trabajo con JAX y Cloud TPU.
Antes de comenzar
Antes de ejecutar los comandos de este documento, debes crear una cuenta de Google Cloud, instalar Google Cloud CLI y configurar el comando gcloud. Para obtener más información, consulta Configura el entorno de Cloud TPU.
Crea una VM de Cloud TPU con gcloud
Define algunas variables de entorno para facilitar el uso de los comandos.
El Google Cloud ID de tu proyecto. Usa un proyecto existente o
crea uno nuevo.
TPU_NAME
El nombre de la TPU.
ZONE
Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU.
ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas
crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
Para crear tu VM de TPU, ejecuta el siguiente comando desde Cloud Shell o la terminal de tu computadora en la que esté instalada la CLI de Google Cloud.
Si no puedes conectarte a una VM de TPU mediante SSH, es posible que sea porque la VM de TPU
no tiene una dirección IP externa. Para acceder a una VM de TPU sin una dirección IP externa, sigue las instrucciones que se indican en Cómo conectarse a una VM de TPU sin una dirección IP pública.
Verifica que JAX pueda acceder a la TPU y ejecutar operaciones básicas:
Inicia el intérprete de Python 3:
(vm)$python3
>>>importjax
Muestra la cantidad de núcleos de TPU disponibles:
>>>jax.device_count()
Se muestra la cantidad de núcleos de TPU. La cantidad de núcleos que se muestran depende de la versión de TPU que usas. Para obtener más información, consulta Versiones de TPU.
Cómo realizar un cálculo
>>>jax.numpy.add(1,1)
Se muestra el resultado de la adición de numpy:
Resultado del comando:
Array(2,dtype=int32,weak_type=True)
Sal del intérprete de Python
>>>exit()
Ejecuta el código JAX en una VM de TPU
Ahora puedes ejecutar cualquier código JAX que desees. Los ejemplos de Flax
son un excelente punto de partida para ejecutar modelos de AA estándar en JAX. Por ejemplo, para entrenar una red convolucional básica de MNIST, haz lo siguiente:
La secuencia de comandos descarga el conjunto de datos y comienza el entrenamiento. El resultado de la secuencia de comandos debería verse de la siguiente manera:
Ejecuta el siguiente comando para verificar que los recursos se hayan borrado. Asegúrate de que tu TPU ya no aparezca en la lista. La eliminación puede tardar varios minutos.
$gcloudcomputetpustpu-vmlist\--zone=$ZONE
Notas de rendimiento
Estos son algunos detalles importantes que son particularmente relevantes para usar TPU en JAX.
Relleno
Una de las causas más comunes del rendimiento lento en las TPU es la introducción de relleno imprevisto:
Los arreglos en Cloud TPU están en mosaicos. Esto implica el relleno de una de las dimensiones a un múltiplo de 8 y de otra a un múltiplo de 128.
La unidad de multiplicación de matrices funciona mejor con pares de matrices grandes que minimizan la necesidad de relleno.
bfloat16 dtype
De forma predeterminada, la multiplicación de matrices en JAX en TPU usa bfloat16 con acumulación de float32. Esto se puede controlar con el argumento de precisión en las llamadas a la función jax.numpy relevantes (matmul, dot, einsum, etc.). En particular:
precision=jax.lax.Precision.DEFAULT: Usa precisión bfloat16 mixta (más rápida).
precision=jax.lax.Precision.HIGH: Usa varios pases de MXU para lograr una mayor precisión.
precision=jax.lax.Precision.HIGHEST: Usa aún más pases de MXU para lograr una precisión completa de float32.
JAX también agrega el tipo de datos bfloat16, que puedes usar para transmitir arrays de forma explícita a bfloat16. Por ejemplo, jax.numpy.array(x, dtype=jax.numpy.bfloat16)
¿Qué sigue?
Para obtener más información sobre Cloud TPU, consulta los siguientes vínculos:
[[["Fácil de comprender","easyToUnderstand","thumb-up"],["Resolvió mi problema","solvedMyProblem","thumb-up"],["Otro","otherUp","thumb-up"]],[["Difícil de entender","hardToUnderstand","thumb-down"],["Información o código de muestra incorrectos","incorrectInformationOrSampleCode","thumb-down"],["Faltan la información o los ejemplos que necesito","missingTheInformationSamplesINeed","thumb-down"],["Problema de traducción","translationIssue","thumb-down"],["Otro","otherDown","thumb-down"]],["Última actualización: 2025-09-04 (UTC)"],[],[],null,["# Run a calculation on a Cloud TPU VM using JAX\n=============================================\n\nThis document provides a brief introduction to working with JAX and Cloud TPU.\n| **Note:** This example shows how to run code on a v5litepod-8 (v5e) TPU which is a single-host TPU. Single-host TPUs have only 1 TPU VM. To run code on TPUs with more than one TPU VM (for example, v5litepod-16 or larger), see [Run JAX code on Cloud TPU slices](/tpu/docs/jax-pods).\n\n\nBefore you begin\n----------------\n\nBefore running the commands in this document, you must create a Google Cloud\naccount, install the Google Cloud CLI, and configure the `gcloud` command. For\nmore information, see [Set up the Cloud TPU environment](/tpu/docs/setup-gcp-account).\n\nCreate a Cloud TPU VM using `gcloud`\n------------------------------------\n\n1. Define some environment variables to make commands easier to use.\n\n\n ```bash\n export PROJECT_ID=your-project-id\n export TPU_NAME=your-tpu-name\n export ZONE=us-east5-a\n export ACCELERATOR_TYPE=v5litepod-8\n export RUNTIME_VERSION=v2-alpha-tpuv5-lite\n ``` \n\n #### Environment variable descriptions\n\n \u003cbr /\u003e\n\n2. Create your TPU VM by running the following command from a Cloud Shell or\n your computer terminal where the [Google Cloud CLI](/sdk/docs/install)\n is installed.\n\n ```bash\n $ gcloud compute tpus tpu-vm create $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE \\\n --accelerator-type=$ACCELERATOR_TYPE \\\n --version=$RUNTIME_VERSION\n ```\n\nConnect to your Cloud TPU VM\n----------------------------\n\nConnect to your TPU VM over SSH by using the following command: \n\n```bash\n$ gcloud compute tpus tpu-vm ssh $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE\n```\n\nIf you fail to connect to a TPU VM using SSH, it might be because the TPU VM\ndoesn't have an external IP address. To access a TPU VM without an external IP\naddress, follow the instructions in [Connect to a TPU VM without a public IP\naddress](/tpu/docs/tpu-iap).\n\nInstall JAX on your Cloud TPU VM\n--------------------------------\n\n```bash\n(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n```\n\nSystem check\n------------\n\nVerify that JAX can access the TPU and can run basic operations:\n\n1. Start the Python 3 interpreter:\n\n ```bash\n (vm)$ python3\n ``` \n\n ```bash\n \u003e\u003e\u003e import jax\n ```\n2. Display the number of TPU cores available:\n\n ```bash\n \u003e\u003e\u003e jax.device_count()\n ```\n\nThe number of TPU cores is displayed. The number of cores displayed is dependent\non the TPU version you are using. For more information, see [TPU versions](/tpu/docs/system-architecture-tpu-vm#versions).\n\n### Perform a calculation\n\n```bash\n\u003e\u003e\u003e jax.numpy.add(1, 1)\n```\n\nThe result of the numpy add is displayed:\n\nOutput from the command: \n\n```bash\nArray(2, dtype=int32, weak_type=True)\n```\n\n\u003cbr /\u003e\n\n### Exit the Python interpreter\n\n```bash\n\u003e\u003e\u003e exit()\n```\n\nRunning JAX code on a TPU VM\n----------------------------\n\nYou can now run any JAX code you want. The [Flax examples](https://github.com/google/flax/tree/master/examples)\nare a great place to start with running standard ML models in JAX. For example,\nto train a basic MNIST convolutional network:\n\n1. Install Flax examples dependencies:\n\n ```bash\n (vm)$ pip install --upgrade clu\n (vm)$ pip install tensorflow\n (vm)$ pip install tensorflow_datasets\n ```\n2. Install Flax:\n\n ```bash\n (vm)$ git clone https://github.com/google/flax.git\n (vm)$ pip install --user flax\n ```\n3. Run the Flax MNIST training script:\n\n ```bash\n (vm)$ cd flax/examples/mnist\n (vm)$ python3 main.py --workdir=/tmp/mnist \\\n --config=configs/default.py \\\n --config.learning_rate=0.05 \\\n --config.num_epochs=5\n ```\n\nThe script downloads the dataset and starts training. The script output should\nlook like this: \n\n```bash\nI0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88\nI0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72\nI0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04\nI0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15\nI0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18\n```\n\n\nClean up\n--------\n\n\nTo avoid incurring charges to your Google Cloud account for\nthe resources used on this page, follow these steps.\n\nWhen you are done with your TPU VM, follow these steps to clean up your resources.\n\n1. Disconnect from the Cloud TPU instance, if you have not already done so:\n\n ```bash\n (vm)$ exit\n ```\n\n Your prompt should now be username@projectname, showing you are in the Cloud Shell.\n2. Delete your Cloud TPU:\n\n ```bash\n $ gcloud compute tpus tpu-vm delete $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE\n ```\n3. Verify the resources have been deleted by running the following command. Make\n sure your TPU is no longer listed. The deletion might take several minutes.\n\n ```bash\n $ gcloud compute tpus tpu-vm list \\\n --zone=$ZONE\n ```\n\nPerformance notes\n-----------------\n\nHere are a few important details that are particularly relevant to using TPUs in\nJAX.\n\n### Padding\n\nOne of the most common causes for slow performance on TPUs is introducing\ninadvertent padding:\n\n- Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.\n- The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.\n\n### bfloat16 dtype\n\nBy default, matrix multiplication in JAX on TPUs uses [bfloat16](/tpu/docs/bfloat16)\nwith float32 accumulation. This can be controlled with the precision argument on\nrelevant `jax.numpy` function calls (matmul, dot, einsum, etc). In particular:\n\n- `precision=jax.lax.Precision.DEFAULT`: uses mixed bfloat16 precision (fastest)\n- `precision=jax.lax.Precision.HIGH`: uses multiple MXU passes to achieve higher precision\n- `precision=jax.lax.Precision.HIGHEST`: uses even more MXU passes to achieve full float32 precision\n\nJAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to\n`bfloat16`. For example,\n`jax.numpy.array(x, dtype=jax.numpy.bfloat16)`.\n\n\nWhat's next\n-----------\n\nFor more information about Cloud TPU, see:\n\n- [Run JAX code on TPU slices](/tpu/docs/jax-pods)\n- [Manage TPUs](/tpu/docs/managing-tpus-tpu-vm)\n- [Cloud TPU System architecture](/tpu/docs/system-architecture-tpu-vm)"]]