Exécuter un calcul sur une VM Cloud TPU à l'aide de JAX

Ce document présente brièvement comment travailler avec JAX et Cloud TPU.

Avant de suivre ce guide de démarrage rapide, vous devez créer une instance Google Cloud Platform , installez la Google Cloud CLI et configurez la commande gcloud. Pour plus d'informations, consultez la page Configurer un compte et un projet Cloud TPU.

Installer Google Cloud CLI

La Google Cloud CLI contient des outils et des bibliothèques permettant d'interagir avec produits et services Google Cloud. Pour en savoir plus, consultez Installer la Google Cloud CLI

Configurer la commande gcloud

Exécutez les commandes suivantes pour configurer gcloud afin qu'il utilise votre projet et installer les composants nécessaires à la preview de la VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Activer l'API Cloud TPU

  1. Activez l'API Cloud TPU à l'aide de la commande gcloud suivante dans Cloud Shell : Vous pouvez également l'activer depuis la console Google Cloud).

    $ gcloud services enable tpu.googleapis.com
    
  2. Exécutez la commande suivante pour créer une identité de service.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Créer une VM Cloud TPU avec gcloud

Avec les VM Cloud TPU, votre modèle et votre code s'exécutent directement sur l'hôte TPU. machine. Vous vous connectez directement à l'hôte TPU via SSH. Vous pouvez exécuter du code arbitraire, installer des packages, afficher les journaux et déboguer du code directement sur l'hôte TPU.

  1. Créez votre VM TPU en exécutant la commande suivante depuis une session Cloud Shell ou le terminal de votre ordinateur où la Google Cloud CLI est installé.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base
    

    Champs obligatoires

    zone
    Zone dans laquelle vous souhaitez créer votre Cloud TPU.
    accelerator-type
    Le type d'accélérateur spécifie la version et la taille de la ressource Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez Versions de TPU.
    version
    Version logicielle de Cloud TPU. Pour tous les types de TPU, utilisez tpu-ubuntu2204-base

Se connecter à la VM Cloud TPU

Connectez-vous en SSH à votre VM TPU à l'aide de la commande suivante :

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

Champs obligatoires

tpu_name
Nom de la VM TPU à laquelle vous vous connectez.
zone
Zone dans laquelle vous avez créé votre Cloud TPU.

Installer JAX sur votre VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Vérifier le système

Vérifiez que JAX peut accéder au TPU et exécuter des opérations de base:

Démarrez l'interpréteur Python 3 :

(vm)$ python3
>>> import jax

Affichez le nombre de cœurs de TPU disponibles :

>>> jax.device_count()

Le nombre de cœurs de TPU s'affiche. Si vous utilisez un TPU v4, il doit s'agir 4 Si vous utilisez un TPU v2 ou v3, il doit s'agir de 8.

Effectuez un calcul simple :

>>> jax.numpy.add(1, 1)

Le résultat de l'ajout de Numpy s'affiche :

Sortie de la commande :

Array(2, dtype=int32, weak_type=true)

Quittez l'interpréteur Python :

>>> exit()

Exécuter du code JAX sur une VM TPU

Vous pouvez maintenant exécuter le code JAX de votre choix. Les exemples de type flax constituent un bon point de départ pour exécuter des modèles de ML standards dans JAX. Par exemple : pour entraîner un réseau convolutif MNIST de base:

  1. Installer les dépendances des exemples Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. Installer FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. Exécutez le script d'entraînement FLAX MNIST.

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

Le script télécharge l'ensemble de données et commence l'entraînement. Le résultat du script doit se présente comme suit:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Effectuer un nettoyage

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait :

    (vm)$ exit
    
  2. Supprimez votre Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
    
  3. Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

Remarques sur les performances

Voici quelques informations importantes, particulièrement pertinentes pour l'utilisation de TPU dans JAX.

Remplissage

L'une des causes les plus courantes de ralentissement des performances sur les TPU consiste à introduire une marge intérieure inattendue :

  • Les tableaux dans Cloud TPU sont tuilés. Cela implique de remplir l'un des une dimension à un multiple de 8, et une dimension différente à un multiple de 8 128.
  • L'unité de multiplication matricielle fonctionne mieux avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.

bfloat16 dtype

Par défaut, la multiplication matricielle en JAX sur TPU utilise bfloat16 avec une accumulation de float32. Elle peut être contrôlée à l'aide de l'argument de précision pour les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :

  • precision=jax.lax.Precision.DEFAULT: utilise une plage bfloat16 mixte. précision (plus rapide)
  • precision=jax.lax.Precision.HIGH: utilise plusieurs cartes MXU pour d'obtenir une plus grande précision
  • precision=jax.lax.Precision.HIGHEST: utilise encore plus de cartes MXU pour obtenir une précision de type float32

JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour caster explicitement des tableaux en bfloat16, par exemple : jax.numpy.array(x, dtype=jax.numpy.bfloat16)

Exécuter JAX dans un Colab

Lorsque vous exécutez du code JAX dans un notebook Colab, Colab crée automatiquement un ancien nœud TPU. Les nœuds TPU ont une architecture différente. Pour en savoir plus, consultez la page Architecture du système.

Étape suivante

Pour en savoir plus sur Cloud TPU, consultez les pages suivantes :