Point de contrôle automatique Cloud TPU [version Preview publique]
Présentation
Par le passé, lorsqu'une VM TPU nécessite une maintenance, la procédure est lancée immédiatement, sans laisser aux utilisateurs le temps d'effectuer des actions préservant la progression, telles que l'enregistrement d'un point de contrôle. Ce processus est illustré dans la figure 1(a).
Fig. 1. Illustration de la fonctionnalité de point de contrôle automatique : (a) Sans point de contrôle automatique, la progression de l'entraînement du dernier point de contrôle est perdue en cas d'événement de maintenance à venir. (b) Avec Autocheckpoint, la progression de l'entraînement depuis le dernier point de contrôle peut être conservée en cas d'événement de maintenance à venir.
Vous pouvez utiliser le point de contrôle automatique (figure 1(b)) pour conserver la progression de l'entraînement en configurant votre code pour enregistrer un point de contrôle non planifié lorsqu'un événement de maintenance se produit. Lorsqu'un événement de maintenance se produit, la progression depuis le dernier point de contrôle est automatiquement enregistrée. Cette fonctionnalité est compatible avec les tranches uniques et les multitranches.
La fonctionnalité de point de contrôle automatique fonctionne avec des frameworks pouvant capturer un signal SIGTERM, puis enregistrer un point de contrôle. Les frameworks compatibles incluent MaxText, Pax et JAX avec Orbax. Nous annoncerons la compatibilité d'autres frameworks dès qu'ils seront disponibles.
Pour le moment, seuls les TPU (v2-v4 et v5e) créés via l'API Cloud TPU peuvent utiliser cette fonctionnalité. La compatibilité des TPU dans GKE sera annoncée dès qu'elle sera disponible.
Utiliser le point de contrôle automatique
La fonctionnalité de point de contrôle automatique est désactivée par défaut. Lorsque vous créez un TPU ou une ressource en file d'attente, vous pouvez l'activer en ajoutant l'option --autocheckpoint-enabled
lors du provisionnement du TPU.
Lorsque cette fonctionnalité est activée, Cloud TPU effectue les étapes suivantes une fois qu'il reçoit une notification concernant un événement de maintenance:
- Capturer le SIGTERM envoyé au processus à l'aide de l'appareil TPU,
- Attend que le processus se termine ou que 5 minutes se soient écoulés, selon la première échéance atteinte, et effectue la maintenance sur les tranches concernées.
Notez que l'infrastructure utilisée par Autocheckpoint est indépendante du framework de ML. Tout framework de ML peut prendre en charge Autocheckpoint s'il peut capturer le signal SIGTERM et lancer un processus de création de points de contrôle.
Dans le code de l'application, vous devez activer les fonctionnalités de point de contrôle automatique fournies par le framework de ML. Dans Pax, par exemple, cela signifie activer les indicateurs de ligne de commande lors du lancement de l'entraînement (consultez le guide de démarrage rapide avec Pax pour l'autocheckpoint). En arrière-plan, les frameworks enregistrent un point de contrôle non planifié à la réception d'un SIGTERM, et la VM TPU concernée fait l'objet d'une maintenance lorsque le TPU n'est plus utilisé.
Guide de démarrage rapide: point de contrôle automatique avec MaxText
MaxText est un LLM hautes performances, arbitrairement évolutif, Open Source et éprouvé, écrit en Python/JAX pur et ciblant les Cloud TPU. MaxText contient toute la configuration nécessaire pour utiliser la fonctionnalité de point de contrôle automatique.
Le fichier README de MaxText décrit deux manières d'exécuter MaxText à grande échelle:
- Utilisation de
multihost_runner.py
(recommandée pour les tests) - Utilisation de
multihost_job.job
(recommandé pour la production)
Si vous utilisez multihost_runner.py
, la seule modification requise consiste à définir l'option autocheckpoint-enabled
lors du provisionnement de la ressource en file d'attente. Si vous utilisez multihost_job.py
, la seule modification requise consiste à spécifier l'option de ligne de commande ENABLE_AUTOCHECKPOINT=true
lors du lancement de la tâche.
Guide de démarrage rapide: point de contrôle automatique avec Pax sur des tranches uniques
Dans cette section, nous fournissons un exemple de configuration et d'utilisation du point de contrôle automatique avec Pax sur une tranche unique. Avec la configuration appropriée:
- Un point de contrôle sera enregistré lors d'un événement de maintenance.
- Cloud TPU effectuera la maintenance sur la ou les VM TPU concernées une fois le point de contrôle enregistré.
- Une fois la maintenance de Cloud TPU terminée, vous pouvez utiliser la VM TPU comme d'habitude.
Utilisez l'option
autocheckpoint-enabled
lors de la création de la VM TPU ou de la ressource en file d'attente.Exemple :
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Installer Pax sur une seule tranche
La fonctionnalité Autocheckpoint fonctionne sur les versions de Pax 1.1.0 ou ultérieures. Sur les VM TPU, installez
jax[tpu]
et la dernière version depaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Lancer l'entraînement avec la configuration appropriée
L'exemple suivant montre comment configurer le modèle
LmCloudSpmd2B
pour enregistrer les points de contrôle déclenchés par Autocheckpoint dans un bucket Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Notez les deux options transmises à la commande:
jax_fully_async_checkpoint
: lorsque cet indicateur est activé,orbax.checkpoint.AsyncCheckpointer
est utilisé. La classeAsyncCheckpointer
enregistre automatiquement un point de contrôle lorsque le script d'entraînement reçoit un signal SIGTERM.exit_after_ondemand_checkpoint
: lorsque cette option est activée, le processus TPU se ferme une fois le point de contrôle automatique enregistré, ce qui déclenche l'exécution immédiate de la maintenance. Si vous n'utilisez pas cet indicateur, l'entraînement se poursuivra après l'enregistrement du point de contrôle et Cloud TPU attendra le délai avant expiration (cinq minutes) avant d'effectuer la maintenance requise.
Guide de démarrage rapide: point de contrôle automatique avec Pax sur plusieurs tranches
Le point de contrôle automatique fonctionne non seulement pour les tranches uniques, mais également pour les tranches multiples. Cette section décrit la procédure à suivre pour utiliser le point de contrôle automatique avec des tranches d'âge multiples.
Spécifiez le point de contrôle automatique lors de la création des ressources en file d'attente.
Un environnement à segments multiples ne peut être provisionné que via une requête de ressource en file d'attente. Comme dans le cas d'une tranche unique, utilisez l'option
autocheckpoint-enabled
dans l'appel pour créer une ressource en file d'attente.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Reportez-vous au guide de l'utilisateur de multislice pour plus de détails sur toutes les options disponibles. Une fois la requête de ressource en file d'attente créée et à l'état
ACTIVE
, suivez les étapes suivantes pour exécuter Pax avec Autocheckpoint.Installez Pax sur toutes les VM de l'environnement multislice.
Sur les VM TPU, installez
jax[tpu]
et la dernière version depaxml
sur toutes les VM TPU de votre environnement multitranche:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Lancer l'entraînement avec la configuration appropriée
Cet exemple montre comment configurer le modèle
LmCloudSpmd2B
pour le point de contrôle automatique lors de l'entraînement dans un environnement à plusieurs tranches. Avant d'exécuter le script d'entraînement, définissez DCN_MESH_SHAPE sur [2, 1, 1], comme indiqué dans le code suivant:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Lors du lancement de l'entraînement, en plus des indicateurs de ligne de commande décrits dans le cas de la tranche unique, trois autres sont nécessaires:
num_hosts
: nombre total d'hôtes Dans ce cas, il s'agit de 2.host_index
: index de l'hôte qui lance l'entraînement. Elle varie de 0 àN-1
, oùN
correspond au nombre total d'hôtes.server_addr
: adresse IP du nœud de calcul 0 du nœud 0, avec un port inutilisé (par exemple, 8476) Pour obtenir ces informations, utilisezhostname -i
sur le nœud de calcul 0 du nœud 0.
Point de contrôle automatique avec Orbax
La fonctionnalité de point de contrôle automatique ne se limite pas à MaxText ni à Pax. Tout framework capable de capturer le signal SIGTERM et de lancer un processus de création de points de contrôle fonctionne avec l'infrastructure fournie par Autocheckpoint. Orbax, un espace de noms qui fournit des bibliothèques d'utilitaires courantes pour les utilisateurs JAX, propose ces fonctionnalités.
Comme expliqué dans la documentation Orbax, ces fonctionnalités sont activées par défaut pour les utilisateurs de orbax.checkpoint.CheckpointManager
. La méthode save
appelée après chaque étape vérifie automatiquement si un événement de maintenance est imminent et, le cas échéant, enregistre un point de contrôle même si le numéro d'étape n'est pas un multiple de save_interval_steps
.
La documentation GitHub montre également comment quitter l'entraînement après avoir enregistré un point de contrôle automatique, avec une modification du code utilisateur.