Entraînement sur un TPU à hôte unique à l'aide de Pax


Ce document présente brièvement l'utilisation de Pax sur un TPU à hôte unique (v2-8, v3-8, v4-8).

Pax est un framework permettant de configurer et d'exécuter des expériences de machine learning sur JAX. Pax se concentre sur la simplification du ML à grande échelle en partageant des composants d'infrastructure avec les frameworks de ML existants et en utilisant la bibliothèque de modélisation Praxis pour la modularité.

Objectifs

  • Configurer des ressources TPU pour l'entraînement
  • Installer Pax sur un TPU à hôte unique
  • Entraîner un modèle SPMD basé sur un Transformer à l'aide de Pax

Avant de commencer

Exécutez les commandes suivantes pour configurer gcloud afin d'utiliser votre projet Cloud TPU et d'installer les composants nécessaires pour entraîner un modèle exécutant Pax sur un TPU à hôte unique.

Installer Google Cloud CLI

La CLI Google Cloud contient des outils et des bibliothèques permettant d'interagir avec les produits et services du Google Cloud CLI. Si vous ne l'avez pas déjà installé, faites-le maintenant en suivant les instructions de la section Installer la Google Cloud CLI.

Configurer la commande gcloud

(Exécutez gcloud auth list pour afficher vos comptes disponibles).

$ gcloud config set account account

$ gcloud config set project project-id

Activer l'API Cloud TPU

Activez l'API Cloud TPU à l'aide de la commande gcloud suivante dans Cloud Shell. Vous pouvez également l'activer à partir de Google Cloud Console.

$ gcloud services enable tpu.googleapis.com

Exécutez la commande suivante pour créer une identité de service (un compte de service).

$ gcloud beta services identity create --service tpu.googleapis.com

Créer une VM TPU

Avec les VM Cloud TPU, votre modèle et votre code s'exécutent directement sur la VM TPU. Vous vous connectez directement à la VM TPU via SSH. Vous pouvez exécuter du code arbitraire, installer des packages, afficher les journaux et déboguer le code directement sur la VM TPU.

Créez votre VM TPU en exécutant la commande suivante à partir d'un environnement Cloud Shell ou du terminal d'ordinateur sur lequel la Google Cloud CLI est installée.

Définissez zone en fonction de la disponibilité dans votre contrat. Consultez la section Régions et zones TPU si nécessaire.

Définissez la variable accelerator-type sur v2-8, v3-8 ou v4-8.

Définissez la variable version sur tpu-vm-base pour les versions TPU v2 et v3, ou sur tpu-vm-v4-base pour les TPU v4.

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version

Se connecter à votre VM Google Cloud TPU

Connectez-vous en SSH à votre VM TPU à l'aide de la commande suivante :

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

Lorsque vous êtes connecté à la VM, votre invite d'interface système passe de username@projectname à username@vm-name:

Installer Pax sur la VM Google Cloud TPU

Installez Pax, JAX et libtpu sur votre VM TPU à l'aide des commandes suivantes:

(vm)$ python3 -m pip install -U pip \
python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Vérifier le système

Vérifiez que tout est correctement installé en vérifiant que JAX voit les cœurs TPU:

(vm)$ python3 -c "import jax; print(jax.device_count())"

Le nombre de cœurs de TPU affiché doit être 8 si vous utilisez un TPU v2-8 ou v3-8, ou 4 si vous utilisez un TPU v4-8.

Exécuter du code Pax sur une VM TPU

Vous pouvez maintenant exécuter n'importe quel code Pax. Les exemples de lm_cloud constituent un excellent point de départ pour exécuter des modèles dans Pax. Par exemple, les commandes suivantes entraînent un modèle de langage SPMD basé sur un transformateur avec 2 milliards de paramètres sur des données synthétiques.

Les commandes suivantes affichent le résultat de l'entraînement pour un modèle de langage SPMD. L'entraînement dure 300 étapes en environ 20 minutes.

(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

Sur la tranche v4-8, le résultat doit inclure les éléments suivants:

Pertes et temps d'étape

Tensor de résumé à l'étape=step_# loss = loss
Tensor de résumé à l'étape=step_# Pas par seconde x

Effectuer un nettoyage

Pour éviter que les ressources utilisées lors de ce tutoriel soient facturées sur votre compte Google Cloud, supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

Lorsque vous avez fini d'utiliser votre VM TPU, procédez comme suit pour nettoyer vos ressources.

Déconnectez-vous de l'instance Compute Engine, si vous ne l'avez pas déjà fait:

(vm)$ exit

Supprimez votre Cloud TPU.

$ gcloud compute tpus tpu-vm delete tpu-name  --zone zone

Étape suivante

Pour en savoir plus sur Cloud TPU, consultez les pages suivantes :