Exécuter un calcul sur une VM Cloud TPU à l'aide de JAX
Ce document présente brièvement l'utilisation de JAX et de Cloud TPU.
Avant de suivre ce guide de démarrage rapide, vous devez créer un compte Google Cloud Platform, installer Google Cloud CLI et configurer la commande gcloud
.
Pour plus d'informations, consultez la page Configurer un compte et un projet Cloud TPU.
Installer Google Cloud CLI
La CLI Google Cloud contient des outils et des bibliothèques permettant d'interagir avec les produits et services Google Cloud. Pour en savoir plus, consultez la page Installer Google Cloud CLI.
Configurer la commande gcloud
Exécutez les commandes suivantes pour configurer gcloud
afin d'utiliser votre projet Google Cloud et d'installer les composants requis pour la version bêta de la VM TPU.
$ gcloud config set account your-email-account $ gcloud config set project your-project-id
Activer l'API Cloud TPU
Activez l'API Cloud TPU à l'aide de la commande
gcloud
suivante dans Cloud Shell. Vous pouvez également l'activer à partir de Google Cloud Console.$ gcloud services enable tpu.googleapis.com
Exécutez la commande suivante pour créer une identité de service.
$ gcloud beta services identity create --service tpu.googleapis.com
Créer une VM Cloud TPU avec gcloud
Avec les VM Cloud TPU, votre modèle et votre code s'exécutent directement sur la machine hôte TPU. Vous vous connectez directement à l'hôte TPU via SSH. Vous pouvez exécuter du code arbitraire, installer des packages, afficher les journaux et déboguer le code directement sur l'hôte TPU.
Créez votre VM TPU en exécutant la commande suivante à partir d'un environnement Cloud Shell ou du terminal d'ordinateur sur lequel la Google Cloud CLI est installée.
(vm)$ gcloud compute tpus tpu-vm create tpu-name \ --zone=us-central1-a \ --accelerator-type=v3-8 \ --version=tpu-ubuntu2204-base
Champs obligatoires
zone
- Zone dans laquelle vous souhaitez créer votre Cloud TPU.
accelerator-type
- Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
version
- Version logicielle de Cloud TPU. Pour tous les types de TPU, utilisez
tpu-ubuntu2204-base
.
Se connecter à la VM Cloud TPU
Connectez-vous en SSH à votre VM TPU à l'aide de la commande suivante :
$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central1-a
Champs obligatoires
tpu_name
- Nom de la VM TPU à laquelle vous vous connectez.
zone
- Zone dans laquelle vous avez créé votre Cloud TPU.
Installer JAX sur votre VM Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Vérifier le système
Vérifiez que JAX peut accéder au TPU et exécuter des opérations de base:
Démarrez l'interpréteur Python 3 :
(vm)$ python3
>>> import jax
Affichez le nombre de cœurs de TPU disponibles :
>>> jax.device_count()
Le nombre de cœurs de TPU est affiché. Si vous utilisez un TPU v4, cette valeur doit être 4
. Si vous utilisez un TPU v2 ou v3, cette valeur doit être 8
.
Effectuez un calcul simple :
>>> jax.numpy.add(1, 1)
Le résultat de l'ajout de Numpy s'affiche :
Sortie de la commande :
Array(2, dtype=int32, weak_type=true)
Quittez l'interpréteur Python :
>>> exit()
Exécuter du code JAX sur une VM TPU
Vous pouvez maintenant exécuter n'importe quel code JAX. Les exemples de type flax constituent un bon point de départ pour exécuter des modèles de ML standards dans JAX. Par exemple, pour entraîner un réseau convolutif MNIST de base:
Installer les dépendances des exemples Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Installer FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Exécutez le script d'entraînement FLAX MNIST.
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Le script télécharge l'ensemble de données et lance l'entraînement. Le résultat du script doit se présenter comme suit:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Effectuer un nettoyage
Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.
Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait :
(vm)$ exit
Supprimez votre Cloud TPU.
$ gcloud compute tpus tpu-vm delete tpu-name \ --zone=us-central1-a
Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.
$ gcloud compute tpus tpu-vm list \ --zone=us-central1-a
Remarques sur les performances
Voici quelques informations importantes, particulièrement pertinentes pour l'utilisation de TPU dans JAX.
Remplissage
L'une des causes les plus courantes de ralentissement des performances sur les TPU consiste à introduire une marge intérieure inattendue :
- Les tableaux dans Cloud TPU sont tuilés. Cela implique de compléter l'une des dimensions jusqu'à un multiple de 8, et une autre jusqu'à un multiple de 128.
- L'unité de multiplication matricielle fonctionne mieux avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.
bfloat16 dtype
Par défaut, la multiplication matricielle dans JAX sur les TPU utilise bfloat16 avec l'accumulation float32. Elle peut être contrôlée à l'aide de l'argument de précision pour les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :
precision=jax.lax.Precision.DEFAULT
: utilise la précision bfloat16 mixte (la plus rapide).precision=jax.lax.Precision.HIGH
: utilise plusieurs passes MXU pour obtenir une précision plus élevée.precision=jax.lax.Precision.HIGHEST
: utilise encore plus de passes MXU pour obtenir une précision float32 complète.
JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour caster explicitement des tableaux dans bfloat16
, par exemple, jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Exécuter JAX dans un Colab
Lorsque vous exécutez du code JAX dans un notebook Colab, Colab crée automatiquement un ancien nœud TPU. Les nœuds TPU ont une architecture différente. Pour en savoir plus, consultez la page Architecture du système.
Étape suivante
Pour en savoir plus sur Cloud TPU, consultez les pages suivantes :