Guide de démarrage rapide pour la VM Cloud TPU JAX

Ce document présente brièvement le fonctionnement de JAX et de Cloud TPU.

Connectez-vous à votre compte Google. Si vous n'avez pas encore de compte, créez-en un. Dans Google Cloud Console, sélectionnez ou créez un projet Cloud à partir de la page de sélection du projet. Assurez-vous que la facturation est activée pour votre projet.

Installer le SDK Google Cloud

Le SDK Google Cloud contient des outils et des bibliothèques permettant d'interagir avec les produits et services Google Cloud. Pour en savoir plus, consultez la page Installer le SDK Google Cloud.

Configurer la commande gcloud

Exécutez les commandes suivantes pour configurer gcloud afin qu'il utilise votre projet GCP et installe les composants nécessaires à l'aperçu de la VM TPU.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Activer l'API Cloud TPU

  1. Activez l'API Cloud TPU à l'aide de la commande gcloud suivante dans Cloud Shell. (Vous pouvez également l'activer depuis Google Cloud Console.)

    $ gcloud services enable tpu.googleapis.com
    
  2. Exécutez la commande suivante pour créer une identité de service.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Créer une VM Cloud TPU avec gcloud

Avec les VM Cloud TPU, le modèle et le code s'exécutent directement sur la machine hôte TPU. Vous vous connectez en SSH directement à l'hôte TPU. Vous pouvez exécuter du code arbitraire, installer des packages, afficher les journaux et déboguer le code directement sur l'hôte TPU.

  1. Créez votre VM TPU en exécutant la commande suivante à partir d'une console Cloud Shell GCP ou du terminal de votre ordinateur sur lequel le SDK Google Cloud est installé.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    Champs obligatoires

    zone
    Zone dans laquelle vous souhaitez créer votre Cloud TPU.
    accelerator-type
    Type de Cloud TPU à créer.
    version
    La version d'exécution de Cloud TPU. Définissez ce paramètre sur "v2-alpha" lorsque vous utilisez JAX sur des appareils TPU uniques, des tranches de pod ou des pods entiers.

Se connecter à la VM Cloud TPU

Connectez-vous en SSH à la VM TPU à l'aide de la commande suivante:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Champs obligatoires

tpu_name
Nom de la VM TPU à laquelle vous vous connectez.
zone
Zone dans laquelle vous avez créé votre Cloud TPU.

Installer JAX sur votre VM Cloud TPU

(vm)$ pip3 install --upgrade jax jaxlib

Vérification du système

Vérifiez que tout est correctement installé en vérifiant que JAX voit les cœurs Cloud TPU et peut exécuter des opérations de base:

Démarrez l'interpréteur Python 3:

(vm)$ python3
>>> import jax

Affichez le nombre de cœurs TPU disponibles:

>>> jax.device_count()

Le nombre de cœurs TPU s'affiche. Il doit s'agir de 8.

Effectuez un calcul simple:

>>> jax.numpy.add(1, 1)

Le résultat de l'ajout de numpy s'affiche:

Résultat de la commande:

DeviceArray(2, dtype=int32)

Quittez l'interpréteur Python:

>>> exit()

Exécuter le code JAX sur une VM TPU

Vous pouvez maintenant exécuter le code JAX de votre choix. Les exemples flax sont un bon point de départ pour commencer à exécuter des modèles de ML standards dans JAX. Par exemple, pour entraîner un réseau convolutionnel de base:

  1. Installer des ensembles de données TensorFlow

    (vm)$ pip install --upgrade clu
    
  2. Installez FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. Exécuter le script d'entraînement MNISTXFL

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    Le résultat du script doit se présenter comme suit:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Nettoyer

Lorsque vous avez terminé avec votre VM TPU, procédez comme suit pour nettoyer vos ressources.

  1. Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait :

    (vm)$ exit
    
  2. Supprimez votre Cloud TPU.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Exécutez la commande suivante pour vérifier que les ressources ont bien été supprimées. Assurez-vous que votre TPU n'est plus répertorié. La suppression peut prendre plusieurs minutes.

Remarques sur les performances

Voici quelques détails importants, notamment pour utiliser des TPU dans AJAX.

Remplissage

L'une des causes les plus courantes d'un ralentissement des performances sur les TPU est l'introduction d'un remplissage excessif:

  • Les tableaux dans Cloud TPU sont tuilés. Cela implique de compléter l'une des dimensions jusqu'à un multiple de 8, et une autre jusqu'à un multiple de 128.
  • L'unité de multiplication matricielle offre de meilleures performances avec des paires de matrices volumineuses qui minimisent le besoin de remplissage.

Type de fichier bfloat16

Par défaut, la multiplication matricielle dans JAX sur TPU utilise bfloat16 avec l'accumulation de nombre à virgule flottante. Cette méthode peut être contrôlée à l'aide de l'argument de précision sur les appels de fonction jax.numpy pertinents (matmul, point, einsum, etc.). En particulier :

  • precision=jax.lax.Precision.DEFAULT: utilise la précision mixte bfloat16 (la plus rapide)
  • precision=jax.lax.Precision.HIGH: utilise plusieurs passes MXU pour atteindre une précision plus élevée.
  • precision=jax.lax.Precision.HIGHEST: utilise encore plus de passes MXU pour atteindre une précision complète de float32.

JAX ajoute également le dtype bfloat16, que vous pouvez utiliser pour convertir explicitement des tableaux en bfloat16, par exemple : jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Exécuter JAX dans Colab

Lorsque vous exécutez du code JAX dans un notebook Colab, Colab crée automatiquement un ancien nœud TPU. Les nœuds TPU ont une architecture différente. Pour plus d'informations, consultez la page Architecture du système.