Guide de démarrage rapide pour la VM Cloud TPU avec JAX

Ce document fournit une brève introduction à l'utilisation de JAX et Cloud TPU.

Avant de suivre ce guide de démarrage rapide, vous devez créer un compte Google Cloud Platform et installer le SDK Google Cloud Platform. et configurer la commande gcloud Pour en savoir plus, consultez la section Configurer un compte et un projet Cloud TPU.

Installer le SDK Google Cloud

Le SDK Google Cloud contient des outils et des bibliothèques permettant d'interagir avec les produits et services Google Cloud. Pour en savoir plus, consultez la page Installer le SDK Google Cloud.

Configurer la commande gcloud

Exécutez les commandes suivantes pour configurer gcloud afin d'utiliser votre projet GCP et d'installer les composants nécessaires à l'aperçu de la VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Activer l'API Cloud TPU

  1. Activez l'API Cloud TPU à l'aide de la commande gcloud suivante dans Cloud Shell. (Vous pouvez également l'activer à partir de Google Cloud Console.)

    $ gcloud services enable tpu.googleapis.com
    
  2. Exécutez la commande suivante pour créer une identité de service.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Créer une VM Cloud TPU avec gcloud

Avec les VM Cloud TPU, votre modèle et votre code s'exécutent directement sur la machine hôte TPU. Vous vous connectez directement à l'hôte TPU via SSH. Vous pouvez exécuter du code arbitraire, installer des packages, afficher les journaux et déboguer le code directement sur l'hôte TPU.

  1. Créez votre VM TPU en exécutant la commande suivante à partir d'un environnement GCP Cloud Shell ou du terminal d'ordinateur sur lequel le SDK Google Cloud est installé.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    Champs obligatoires

    zone
    Zone dans laquelle vous souhaitez créer votre Cloud TPU.
    accelerator-type
    Type du Cloud TPU à créer.
    version
    Version d'exécution de Cloud TPU. Définissez cette valeur sur "v2-alpha" lorsque vous utilisez JAX sur des appareils TPU uniques, des tranches de pod ou des pods entiers.

Se connecter à la VM Cloud TPU

Connectez-vous en SSH à votre VM TPU à l'aide de la commande suivante:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Champs obligatoires

tpu_name
Nom de la VM TPU à laquelle vous vous connectez.
zone
Zone dans laquelle vous avez créé votre Cloud TPU.

Installer JAX sur votre VM Cloud TPU

(vm)$ pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Vérification du système

Vérifiez que tout est correctement installé en vérifiant que JAX voit les cœurs Cloud TPU et qu'il peut exécuter des opérations de base:

Démarrez l'interpréteur Python 3:

(vm)$ python3
>>> import jax

Affichez le nombre de cœurs de TPU disponibles:

>>> jax.device_count()

Le nombre de cœurs de TPU affiché doit être 8.

Effectuez un calcul simple:

>>> jax.numpy.add(1, 1)

Le résultat de l'ajout de numpy s'affiche:

Résultat de la commande:

DeviceArray(2, dtype=int32)

Quittez l'interpréteur Python:

>>> exit()

Exécuter du code JAX sur une VM TPU

Vous pouvez à présent exécuter n'importe quel code JAX, s'il vous plaît. Les exemples de lin sont un excellent point de départ pour commencer à exécuter des modèles de ML standards dans JAX. Par exemple, pour entraîner un réseau convolutif MNIST de base, procédez comme suit:

  1. Installer des ensembles de données TensorFlow

    (vm)$ pip install --upgrade clu
    
  2. Installez FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Exécuter le script d'entraînement FLAX MNIST

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    Le résultat du script doit se présenter comme suit:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Nettoyer

Lorsque vous avez terminé avec votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait :

    (vm)$ exit
    
  2. Supprimez votre Cloud TPU.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Vérifiez que les ressources ont bien été supprimées en exécutant la commande suivante : Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.

Remarques sur les performances

Voici quelques informations importantes, particulièrement pertinentes pour l'utilisation de TPU dans JAX.

Remplissage

L'une des causes les plus courantes de lenteur sur les TPU consiste à ajouter une marge intérieure accidentelle:

  • Les tableaux dans Cloud TPU sont tuilés. Cela implique de remplir l'une des dimensions avec un multiple de 8, et une dimension différente avec un multiple de 128.
  • L'unité de multiplication matricielle fonctionne mieux avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.

bfloat16 dtype

Par défaut, la multiplication matricielle dans JAX sur les TPU utilise bfloat16 avec l'accumulation float32. Il peut être contrôlé à l'aide de l'argument de précision sur les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :

  • precision=jax.lax.Precision.DEFAULT: utilise la précision bfloat16 mixte (la plus rapide)
  • precision=jax.lax.Precision.HIGH: utilise plusieurs MXU pour obtenir une précision plus élevée.
  • precision=jax.lax.Precision.HIGHEST: utilise encore plus de passes MXU pour obtenir une précision float32 complète.

JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour caster explicitement des tableaux dans bfloat16. Par exemple : jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Exécuter JAX dans un Colab

Lorsque vous exécutez du code JAX dans un notebook Colab, Colab crée automatiquement un ancien nœud TPU. Les nœuds TPU ont une architecture différente. Pour en savoir plus, consultez la section Architecture système.