Effectuer un calcul sur une VM Cloud TPU à l'aide de JAX

Ce document présente brièvement le travail avec JAX et Cloud TPU.

Avant de suivre ce guide de démarrage rapide, vous devez créer un compte Google Cloud Platform, installer la Google Cloud CLI et configurer la commande gcloud. Pour plus d'informations, consultez la page Configurer un compte et un projet Cloud TPU.

Installer Google Cloud CLI

La Google Cloud CLI contient des outils et des bibliothèques permettant d'interagir avec les produits et services Google Cloud. Pour en savoir plus, consultez la page Installer la Google Cloud CLI.

Configurer la commande gcloud

Exécutez les commandes suivantes pour configurer gcloud afin qu'il utilise votre projet Google Cloud et installer les composants nécessaires à la prévisualisation de la VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Activer l'API Cloud TPU

  1. Activez l'API Cloud TPU à l'aide de la commande gcloud suivante dans Cloud Shell. (Vous pouvez également l'activer à partir de la console Google Cloud.)

    $ gcloud services enable tpu.googleapis.com
    
  2. Exécutez la commande suivante pour créer une identité de service.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Créer une VM Cloud TPU avec gcloud

Avec les VM Cloud TPU, votre modèle et votre code s'exécutent directement sur la machine hôte TPU. Vous vous connectez directement à l'hôte TPU via SSH. Vous pouvez exécuter du code arbitraire, installer des packages, afficher les journaux et déboguer du code directement sur l'hôte TPU.

  1. Créez votre VM TPU en exécutant la commande suivante à partir de Cloud Shell ou du terminal de votre ordinateur sur lequel la Google Cloud CLI est installée.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base
    

    Champs obligatoires

    zone
    Zone dans laquelle vous prévoyez de créer la ressource Cloud TPU.
    accelerator-type
    Le type d'accélérateur spécifie la version et la taille de la ressource Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez la section Versions de TPU.
    version
    Version du logiciel Cloud TPU. Pour tous les types de TPU, utilisez tpu-ubuntu2204-base.

Se connecter à la VM Cloud TPU

Connectez-vous en SSH à votre VM TPU à l'aide de la commande suivante :

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

Champs obligatoires

tpu_name
Nom de la VM TPU à laquelle vous vous connectez.
zone
Zone dans laquelle vous avez créé votre Cloud TPU.

Installer JAX sur votre VM Cloud TPU

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Vérifier le système

Vérifiez que JAX peut accéder au TPU et exécuter des opérations de base:

Démarrez l'interpréteur Python 3 :

(vm)$ python3
>>> import jax

Affichez le nombre de cœurs de TPU disponibles :

>>> jax.device_count()

Le nombre de cœurs de TPU s'affiche. Si vous utilisez un TPU v4, il doit s'agir de 4. Si vous utilisez un TPU v2 ou v3, ce champ doit être 8.

Effectuez un calcul simple :

>>> jax.numpy.add(1, 1)

Le résultat de l'ajout de Numpy s'affiche :

Sortie de la commande :

Array(2, dtype=int32, weak_type=true)

Quittez l'interpréteur Python :

>>> exit()

Exécuter du code JAX sur une VM TPU

Vous pouvez maintenant exécuter le code JAX de votre choix. Les exemples de type flax constituent un bon point de départ pour exécuter des modèles de ML standards dans JAX. Par exemple, pour entraîner un réseau convolutif MNIST de base:

  1. Installer des dépendances d'exemples Flax

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. Installer FLAX

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. Exécutez le script d'entraînement FLAX MNIST.

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

Le script télécharge l'ensemble de données et lance l'entraînement. Le résultat du script doit se présenter comme suit:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Effectuer un nettoyage

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait :

    (vm)$ exit
    
  2. Supprimez votre Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
    
  3. Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

Remarques sur les performances

Voici quelques informations importantes, particulièrement pertinentes pour l'utilisation de TPU dans JAX.

Remplissage

L'une des causes les plus courantes de ralentissement des performances sur les TPU consiste à introduire une marge intérieure inattendue :

  • Les tableaux dans Cloud TPU sont tuilés. Cela implique de remplir une dimension jusqu'à un multiple de 8 et une autre dimension jusqu'à un multiple de 128.
  • L'unité de multiplication matricielle fonctionne mieux avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.

bfloat16 dtype

Par défaut, la multiplication matricielle dans JAX sur TPU utilise bfloat16 avec l'accumulation float32. Elle peut être contrôlée à l'aide de l'argument de précision pour les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :

  • precision=jax.lax.Precision.DEFAULT: utilise la précision mixte bfloat16 (la plus rapide).
  • precision=jax.lax.Precision.HIGH: utilise plusieurs cartes MXU pour obtenir une précision plus élevée
  • precision=jax.lax.Precision.HIGHEST: utilise encore plus de passes MXU pour obtenir une précision float32 complète.

JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour caster explicitement des tableaux sur bfloat16 (par exemple, jax.numpy.array(x, dtype=jax.numpy.bfloat16)).

Exécuter JAX dans un Colab

Lorsque vous exécutez du code JAX dans un notebook Colab, Colab crée automatiquement un ancien nœud TPU. Les nœuds TPU ont une architecture différente. Pour en savoir plus, consultez la page Architecture du système.

Étapes suivantes

Pour en savoir plus sur Cloud TPU, consultez les pages suivantes :