Restez organisé à l'aide des collections
Enregistrez et classez les contenus selon vos préférences.
Exécuter un calcul sur une VM Cloud TPU à l'aide de JAX
Ce document présente brièvement l'utilisation de JAX et de Cloud TPU.
Avant de commencer
Avant d'exécuter les commandes de ce document, vous devez créer un compte Google Cloud, installer Google Cloud CLI et configurer la commande gcloud. Pour en savoir plus, consultez la section Configurer l'environnement Cloud TPU.
Créer une VM Cloud TPU à l'aide de gcloud
Définissez des variables d'environnement pour faciliter l'utilisation des commandes.
L'ID de votre Google Cloud projet. Utilisez un projet existant ou créez-en un.
TPU_NAME
Nom du TPU.
ZONE
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez la section Régions et zones de TPU.
ACCELERATOR_TYPE
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
Créez votre VM TPU en exécutant la commande suivante à partir d'un environnement Cloud Shell ou du terminal d'ordinateur sur lequel la Google Cloud CLI est installée.
Si vous ne parvenez pas à vous connecter à une VM TPU à l'aide de SSH, cela peut être dû au fait qu'elle ne possède pas d'adresse IP externe. Pour accéder à une VM TPU sans adresse IP externe, suivez les instructions de la section Se connecter à une VM TPU sans adresse IP publique.
Vérifiez que JAX peut accéder au TPU et exécuter des opérations de base:
Démarrez l'interpréteur Python 3 :
(vm)$python3
>>>importjax
Affichez le nombre de cœurs de TPU disponibles :
>>>jax.device_count()
Le nombre de cœurs de TPU est affiché. Le nombre de cœurs affiché dépend de la version de TPU que vous utilisez. Pour en savoir plus, consultez la section Versions de TPU.
Effectuer un calcul
>>>jax.numpy.add(1,1)
Le résultat de l'ajout de Numpy s'affiche :
Sortie de la commande :
Array(2,dtype=int32,weak_type=True)
Quitter l'interpréteur Python
>>>exit()
Exécuter du code JAX sur une VM TPU
Vous pouvez maintenant exécuter n'importe quel code JAX. Les exemples de type flax constituent un bon point de départ pour exécuter des modèles de ML standards dans JAX. Par exemple, pour entraîner un réseau convolutif MNIST de base:
Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.
$gcloudcomputetpustpu-vmlist\--zone=$ZONE
Remarques sur les performances
Voici quelques informations importantes, particulièrement pertinentes pour l'utilisation de TPU dans JAX.
Remplissage
L'une des causes les plus courantes de ralentissement des performances sur les TPU consiste à introduire une marge intérieure inattendue :
Les tableaux dans Cloud TPU sont tuilés. Cela implique de compléter l'une des dimensions jusqu'à un multiple de 8, et une autre jusqu'à un multiple de 128.
L'unité de multiplication matricielle fonctionne mieux avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.
bfloat16 dtype
Par défaut, la multiplication matricielle dans JAX sur les TPU utilise bfloat16 avec l'accumulation float32. Elle peut être contrôlée à l'aide de l'argument de précision pour les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :
precision=jax.lax.Precision.DEFAULT: utilise la précision bfloat16 mixte (la plus rapide).
precision=jax.lax.Precision.HIGH: utilise plusieurs passes MXU pour obtenir une précision plus élevée.
precision=jax.lax.Precision.HIGHEST: utilise encore plus de passes MXU pour obtenir une précision float32 complète.
JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour caster explicitement des tableaux dans bfloat16. Exemple : jax.numpy.array(x, dtype=jax.numpy.bfloat16).
Étapes suivantes
Pour en savoir plus sur Cloud TPU, consultez les pages suivantes :
Sauf indication contraire, le contenu de cette page est régi par une licence Creative Commons Attribution 4.0, et les échantillons de code sont régis par une licence Apache 2.0. Pour en savoir plus, consultez les Règles du site Google Developers. Java est une marque déposée d'Oracle et/ou de ses sociétés affiliées.
Dernière mise à jour le 2025/09/04 (UTC).
[[["Facile à comprendre","easyToUnderstand","thumb-up"],["J'ai pu résoudre mon problème","solvedMyProblem","thumb-up"],["Autre","otherUp","thumb-up"]],[["Difficile à comprendre","hardToUnderstand","thumb-down"],["Informations ou exemple de code incorrects","incorrectInformationOrSampleCode","thumb-down"],["Il n'y a pas l'information/les exemples dont j'ai besoin","missingTheInformationSamplesINeed","thumb-down"],["Problème de traduction","translationIssue","thumb-down"],["Autre","otherDown","thumb-down"]],["Dernière mise à jour le 2025/09/04 (UTC)."],[],[],null,["# Run a calculation on a Cloud TPU VM using JAX\n=============================================\n\nThis document provides a brief introduction to working with JAX and Cloud TPU.\n| **Note:** This example shows how to run code on a v5litepod-8 (v5e) TPU which is a single-host TPU. Single-host TPUs have only 1 TPU VM. To run code on TPUs with more than one TPU VM (for example, v5litepod-16 or larger), see [Run JAX code on Cloud TPU slices](/tpu/docs/jax-pods).\n\n\nBefore you begin\n----------------\n\nBefore running the commands in this document, you must create a Google Cloud\naccount, install the Google Cloud CLI, and configure the `gcloud` command. For\nmore information, see [Set up the Cloud TPU environment](/tpu/docs/setup-gcp-account).\n\nCreate a Cloud TPU VM using `gcloud`\n------------------------------------\n\n1. Define some environment variables to make commands easier to use.\n\n\n ```bash\n export PROJECT_ID=your-project-id\n export TPU_NAME=your-tpu-name\n export ZONE=us-east5-a\n export ACCELERATOR_TYPE=v5litepod-8\n export RUNTIME_VERSION=v2-alpha-tpuv5-lite\n ``` \n\n #### Environment variable descriptions\n\n \u003cbr /\u003e\n\n2. Create your TPU VM by running the following command from a Cloud Shell or\n your computer terminal where the [Google Cloud CLI](/sdk/docs/install)\n is installed.\n\n ```bash\n $ gcloud compute tpus tpu-vm create $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE \\\n --accelerator-type=$ACCELERATOR_TYPE \\\n --version=$RUNTIME_VERSION\n ```\n\nConnect to your Cloud TPU VM\n----------------------------\n\nConnect to your TPU VM over SSH by using the following command: \n\n```bash\n$ gcloud compute tpus tpu-vm ssh $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE\n```\n\nIf you fail to connect to a TPU VM using SSH, it might be because the TPU VM\ndoesn't have an external IP address. To access a TPU VM without an external IP\naddress, follow the instructions in [Connect to a TPU VM without a public IP\naddress](/tpu/docs/tpu-iap).\n\nInstall JAX on your Cloud TPU VM\n--------------------------------\n\n```bash\n(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\n```\n\nSystem check\n------------\n\nVerify that JAX can access the TPU and can run basic operations:\n\n1. Start the Python 3 interpreter:\n\n ```bash\n (vm)$ python3\n ``` \n\n ```bash\n \u003e\u003e\u003e import jax\n ```\n2. Display the number of TPU cores available:\n\n ```bash\n \u003e\u003e\u003e jax.device_count()\n ```\n\nThe number of TPU cores is displayed. The number of cores displayed is dependent\non the TPU version you are using. For more information, see [TPU versions](/tpu/docs/system-architecture-tpu-vm#versions).\n\n### Perform a calculation\n\n```bash\n\u003e\u003e\u003e jax.numpy.add(1, 1)\n```\n\nThe result of the numpy add is displayed:\n\nOutput from the command: \n\n```bash\nArray(2, dtype=int32, weak_type=True)\n```\n\n\u003cbr /\u003e\n\n### Exit the Python interpreter\n\n```bash\n\u003e\u003e\u003e exit()\n```\n\nRunning JAX code on a TPU VM\n----------------------------\n\nYou can now run any JAX code you want. The [Flax examples](https://github.com/google/flax/tree/master/examples)\nare a great place to start with running standard ML models in JAX. For example,\nto train a basic MNIST convolutional network:\n\n1. Install Flax examples dependencies:\n\n ```bash\n (vm)$ pip install --upgrade clu\n (vm)$ pip install tensorflow\n (vm)$ pip install tensorflow_datasets\n ```\n2. Install Flax:\n\n ```bash\n (vm)$ git clone https://github.com/google/flax.git\n (vm)$ pip install --user flax\n ```\n3. Run the Flax MNIST training script:\n\n ```bash\n (vm)$ cd flax/examples/mnist\n (vm)$ python3 main.py --workdir=/tmp/mnist \\\n --config=configs/default.py \\\n --config.learning_rate=0.05 \\\n --config.num_epochs=5\n ```\n\nThe script downloads the dataset and starts training. The script output should\nlook like this: \n\n```bash\nI0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88\nI0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72\nI0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04\nI0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15\nI0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18\n```\n\n\nClean up\n--------\n\n\nTo avoid incurring charges to your Google Cloud account for\nthe resources used on this page, follow these steps.\n\nWhen you are done with your TPU VM, follow these steps to clean up your resources.\n\n1. Disconnect from the Cloud TPU instance, if you have not already done so:\n\n ```bash\n (vm)$ exit\n ```\n\n Your prompt should now be username@projectname, showing you are in the Cloud Shell.\n2. Delete your Cloud TPU:\n\n ```bash\n $ gcloud compute tpus tpu-vm delete $TPU_NAME \\\n --project=$PROJECT_ID \\\n --zone=$ZONE\n ```\n3. Verify the resources have been deleted by running the following command. Make\n sure your TPU is no longer listed. The deletion might take several minutes.\n\n ```bash\n $ gcloud compute tpus tpu-vm list \\\n --zone=$ZONE\n ```\n\nPerformance notes\n-----------------\n\nHere are a few important details that are particularly relevant to using TPUs in\nJAX.\n\n### Padding\n\nOne of the most common causes for slow performance on TPUs is introducing\ninadvertent padding:\n\n- Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.\n- The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.\n\n### bfloat16 dtype\n\nBy default, matrix multiplication in JAX on TPUs uses [bfloat16](/tpu/docs/bfloat16)\nwith float32 accumulation. This can be controlled with the precision argument on\nrelevant `jax.numpy` function calls (matmul, dot, einsum, etc). In particular:\n\n- `precision=jax.lax.Precision.DEFAULT`: uses mixed bfloat16 precision (fastest)\n- `precision=jax.lax.Precision.HIGH`: uses multiple MXU passes to achieve higher precision\n- `precision=jax.lax.Precision.HIGHEST`: uses even more MXU passes to achieve full float32 precision\n\nJAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to\n`bfloat16`. For example,\n`jax.numpy.array(x, dtype=jax.numpy.bfloat16)`.\n\n\nWhat's next\n-----------\n\nFor more information about Cloud TPU, see:\n\n- [Run JAX code on TPU slices](/tpu/docs/jax-pods)\n- [Manage TPUs](/tpu/docs/managing-tpus-tpu-vm)\n- [Cloud TPU System architecture](/tpu/docs/system-architecture-tpu-vm)"]]