Exécuter un calcul sur une VM Cloud TPU à l'aide de JAX
Ce document présente brièvement l'utilisation de JAX et de Cloud TPU.
Avant de commencer
Avant d'exécuter les commandes de ce document, vous devez créer un compte Google Cloud, installer Google Cloud CLI et configurer la commande gcloud
. Pour en savoir plus, consultez la section Configurer l'environnement Cloud TPU.
Créer une VM Cloud TPU à l'aide de gcloud
Définissez des variables d'environnement pour faciliter l'utilisation des commandes.
export PROJECT_ID=your-project export ACCELERATOR_TYPE=v5p-8 export ZONE=us-east5-a export RUNTIME_VERSION=v2-alpha-tpuv5 export TPU_NAME=your-tpu-name
Descriptions des variables d'environnement
PROJECT_ID
- Votre Google Cloud ID de projet
ACCELERATOR_TYPE
- Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles pour chaque version de TPU, consultez la section Versions de TPU.
ZONE
- Zone dans laquelle vous prévoyez de créer votre Cloud TPU.
RUNTIME_VERSION
- Version d'exécution de Cloud TPU. Pour en savoir plus, consultez la section Images de VM TPU.
TPU_NAME
- Nom attribué par l'utilisateur à votre Cloud TPU.
Créez votre VM TPU en exécutant la commande suivante à partir d'un environnement Cloud Shell ou du terminal d'ordinateur sur lequel la Google Cloud CLI est installée.
$ gcloud compute tpus tpu-vm create $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE \ --accelerator-type=$ACCELERATOR_TYPE \ --version=$RUNTIME_VERSION
Se connecter à la VM Cloud TPU
Connectez-vous à votre VM TPU via SSH à l'aide de la commande suivante:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Installer JAX sur votre VM Cloud TPU
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Vérifier le système
Vérifiez que JAX peut accéder au TPU et exécuter des opérations de base:
Démarrez l'interpréteur Python 3 :
(vm)$ python3
>>> import jax
Affichez le nombre de cœurs de TPU disponibles :
>>> jax.device_count()
Le nombre de cœurs de TPU est affiché. Le nombre de cœurs affiché dépend de la version de TPU que vous utilisez. Pour en savoir plus, consultez la section Versions de TPU.
Effectuez un calcul:
>>> jax.numpy.add(1, 1)
Le résultat de l'ajout de Numpy s'affiche :
Sortie de la commande :
Array(2, dtype=int32, weak_type=true)
Quittez l'interpréteur Python :
>>> exit()
Exécuter du code JAX sur une VM TPU
Vous pouvez maintenant exécuter n'importe quel code JAX. Les exemples de type flax constituent un bon point de départ pour exécuter des modèles de ML standards dans JAX. Par exemple, pour entraîner un réseau convolutif MNIST de base:
Installer les dépendances des exemples Flax
(vm)$ pip install --upgrade clu (vm)$ pip install tensorflow (vm)$ pip install tensorflow_datasets
Installer FLAX
(vm)$ git clone https://github.com/google/flax.git (vm)$ pip install --user flax
Exécutez le script d'entraînement FLAX MNIST.
(vm)$ cd flax/examples/mnist (vm)$ python3 main.py --workdir=/tmp/mnist \ --config=configs/default.py \ --config.learning_rate=0.05 \ --config.num_epochs=5
Le script télécharge l'ensemble de données et lance l'entraînement. Le résultat du script doit se présenter comme suit:
0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88 I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72 I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04 I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15 I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Effectuer un nettoyage
Pour éviter que les ressources utilisées sur cette page ne soient facturées sur votre compte Google Cloud , procédez comme suit :
Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.
Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait :
(vm)$ exit
Supprimez votre Cloud TPU.
$ gcloud compute tpus tpu-vm delete $TPU_NAME \ --project=$PROJECT_ID \ --zone=$ZONE
Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.
$ gcloud compute tpus tpu-vm list \ --zone=$ZONE
Remarques sur les performances
Voici quelques informations importantes, particulièrement pertinentes pour l'utilisation de TPU dans JAX.
Remplissage
L'une des causes les plus courantes de ralentissement des performances sur les TPU consiste à introduire une marge intérieure inattendue :
- Les tableaux dans Cloud TPU sont tuilés. Cela implique de compléter l'une des dimensions jusqu'à un multiple de 8, et une autre jusqu'à un multiple de 128.
- L'unité de multiplication matricielle fonctionne mieux avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.
bfloat16 dtype
Par défaut, la multiplication matricielle dans JAX sur les TPU utilise bfloat16 avec l'accumulation float32. Elle peut être contrôlée à l'aide de l'argument de précision pour les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :
precision=jax.lax.Precision.DEFAULT
: utilise la précision bfloat16 mixte (la plus rapide).precision=jax.lax.Precision.HIGH
: utilise plusieurs passes MXU pour obtenir une précision plus élevée.precision=jax.lax.Precision.HIGHEST
: utilise encore plus de passes MXU pour obtenir une précision float32 complète.
JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour caster explicitement des tableaux dans bfloat16
, par exemple, jax.numpy.array(x, dtype=jax.numpy.bfloat16)
.
Étape suivante
Pour en savoir plus sur Cloud TPU, consultez les pages suivantes :