Cloud TPU Autocheckpoint [Preview public]
Présentation
Historiquement, lorsqu'une VM TPU nécessite une maintenance, la procédure est lancée immédiatement, sans laisser aux utilisateurs le temps d'effectuer des actions de conservation de la progression, telles que l'enregistrement d'un point de contrôle. C'est ce que montre la figure 1(a).
Fig. 1. Illustration de la fonctionnalité de point de contrôle automatique : (a) Sans point de contrôle automatique, la progression de l'entraînement à partir du dernier point de contrôle est perdue en cas d'événement de maintenance à venir. (b) Avec le point de contrôle automatique, la progression de l'entraînement depuis le dernier point de contrôle peut être conservée en cas d'événement de maintenance à venir.
Vous pouvez utiliser le point de contrôle automatique (figure 1(b)) pour préserver la progression de l'entraînement en la configuration de votre code pour enregistrer un point de contrôle non planifié se produit. Lorsqu'un événement de maintenance se produit, la progression depuis le dernier point de contrôle est automatiquement enregistrée. La fonctionnalité fonctionne sur les deux segments uniques et des multitranches.
La fonctionnalité de point de contrôle automatique fonctionne avec des frameworks capables de capturer SIGTERM et enregistrer ensuite un point de contrôle. Les cadres pris en charge incluent MaxText, Pax, et JAX avec Orbax La prise en charge de frameworks supplémentaires sera annoncée à mesure qu'ils seront disponibles.
Pour le moment, seuls les TPU (versions v2 à v4 et v5e) créés via l'API Cloud TPU peuvent utiliser cette fonctionnalité. Les TPU seront pris en charge dans GKE sera annoncée dès qu'elle sera disponible.
Utiliser le point de contrôle automatique
La fonctionnalité de point de contrôle automatique est désactivée par défaut. Lorsque vous créez un
TPU ou une ressource en file d'attente
vous pouvez l'activer en ajoutant l'option --autocheckpoint-enabled
lors du provisionnement
le TPU.
Une fois cette fonctionnalité activée, Cloud TPU
effectue les étapes suivantes après avoir reçu la notification
événement de maintenance:
- Capturer le SIGTERM envoyé au processus à l'aide de l'appareil TPU,
- Attend que le processus se termine ou que 5 minutes se soient écoulées, selon le cas qui intervient en premier et effectue la maintenance sur les tranches impactées.
Notez que l'infrastructure utilisée par Autocheckpoint est indépendante du framework de ML. Tout framework de ML peut prendre en charge Autocheckpoint à condition qu'il puisse capturer le signal SIGTERM et lancer un processus de création de points de contrôle.
Dans le code de l'application, vous devez activer le point de contrôle automatique fournies par le framework de ML. Dans Pax, par exemple, cela signifie activer les indicateurs de ligne de commande lors du lancement de l'entraînement (voir le démarrage rapide de l'autocheckpoint avec Pax). En coulisse, les frameworks épargnent point de contrôle non planifié lorsqu'un SIGTERM est reçu et la VM TPU concernée fait l'objet d'une maintenance lorsque le TPU n'est plus en cours d'utilisation.
Guide de démarrage rapide: point de contrôle automatique avec MaxText
MaxText est un type de conversion LLM arbitrairement évolutif, Open Source et éprouvé, écrit en Python/JAX pur ciblant les Cloud TPU". MaxText contient toute la configuration nécessaire pour utiliser la fonctionnalité de contrôle automatique.
Le fichier README de MaxText décrit deux façons d'exécuter MaxText à grande échelle :
- Utilisation de
multihost_runner.py
, recommandé pour les tests - Utilisation de
multihost_job.job
, recommandée pour la production
Lorsque vous utilisez multihost_runner.py
, la seule modification requise
consiste à définir l'option autocheckpoint-enabled
lors du provisionnement
la ressource en file d'attente. Lorsque vous utilisez
multihost_job.py
, la seule modification requise consiste à spécifier le
l'option de ligne de commande ENABLE_AUTOCHECKPOINT=true
lors du lancement de la tâche.
Démarrage rapide : point de contrôle automatique avec Pax sur des tranches uniques
Dans cette section, nous fournissons un exemple de configuration et d'utilisation du point de contrôle automatique. avec Pax sur une seule tranche. Avec la configuration appropriée :
- Un point de contrôle est enregistré lorsqu'un événement de maintenance se produit.
- Cloud TPU effectuera la maintenance des VM TPU concernées une fois le point de contrôle enregistré.
- Lorsque la maintenance de Cloud TPU est terminée, vous pouvez utiliser la VM TPU comme d'habitude.
Utilisez l'option
autocheckpoint-enabled
lors de la création de la VM TPU ou ressource en file d'attente.Exemple :
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Installer Pax sur une seule tranche
La fonctionnalité de point de contrôle automatique fonctionne avec les versions Pax 1.1.0 et ultérieures. Sur les VM TPU, installez
jax[tpu]
et la dernière version depaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Lancer l'entraînement avec la configuration appropriée
L'exemple suivant montre comment configurer le modèle
LmCloudSpmd2B
pour enregistrer les points de contrôle déclenchés par Autocheckpoint dans un bucket Google Cloud Storage :JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Notez les deux options transmises à la commande :
jax_fully_async_checkpoint
: Lorsque cet indicateur est activé,orbax.checkpoint.AsyncCheckpointer
est utilisé. La classeAsyncCheckpointer
enregistre automatiquement un point de contrôle lorsque le script d'entraînement reçoit un signal SIGTERM.exit_after_ondemand_checkpoint
: Lorsque cet indicateur est activé, le processus TPU s'arrête après le Le point de contrôle automatique a bien été enregistré, ce qui déclenche la maintenance doit être effectuée immédiatement. Si vous n'utilisez pas l'entraînement se poursuivra une fois le point de contrôle enregistré Cloud TPU attendra le délai avant expiration (5 minutes) avant d'effectuer la maintenance requise.
Guide de démarrage rapide: point de contrôle automatique avec Pax sur plusieurs tranches
Le point de contrôle automatique fonctionne non seulement pour les tranches pour Multislice (Tranches multiples). Cette section détaille la procédure à suivre pour utiliser le point de contrôle automatique avec une multitranche.
Spécifiez le point de contrôle automatique lors de la création des ressources en file d'attente.
Un environnement multislice ne peut être provisionné que via une requête de ressources mise en file d'attente. Comme pour une tranche unique, utilisez l'option
autocheckpoint-enabled
dans l'appel pour créer une ressource en file d'attente.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Consultez le guide de l'utilisateur Multislice pour en savoir plus sur toutes les options disponibles. Une fois que la ressource en file d'attente de la requête est créée et à l'état
ACTIVE
, suivez les étapes suivantes pour exécuter Pax avec Point de contrôle automatique.Installez Pax sur toutes les VM de l'environnement Multislice.
Sur les VM TPU, installez
jax[tpu]
et la dernière version depaxml
sur toutes les VM TPU de votre environnement multi-tranche:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Lancer l'entraînement avec la configuration appropriée
Cet exemple montre comment configurer le modèle
LmCloudSpmd2B
pour le point de contrôle automatique lors de l'entraînement dans un environnement à plusieurs tranches. Avant d'exécuter le script d'entraînement, définissez DCN_MESH_SHAPE sur [2, 1, 1], comme indiqué dans le code suivant :@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Lors du lancement de l'entraînement, en plus des indicateurs de ligne de commande abordés dans le cas d'une seule tranche, trois autres sont requis :
num_hosts
: nombre total d'hôtes. Dans ce cas, il s'agit de 2.host_index
: indice de l'hôte qui lance l'entraînement. Variable de 0 àN-1
, oùN
est le nombre total d'hôtes.server_addr
: adresse IP du nœud de calcul 0 du nœud 0, avec une adresse (8476, par exemple). Pour obtenir ces informations, utilisezhostname -i
sur le nœud de calcul 0.
Point de contrôle automatique avec Orbax
La fonctionnalité de point de contrôle automatique ne se limite pas à MaxText ou Pax. Tout framework capable de capturer le signal SIGTERM et d'initier un processus de point de contrôle fonctionne avec l'infrastructure fournie par Autocheckpoint. Orbax, un espace de noms qui fournit les bibliothèques d'utilitaires courantes pour les utilisateurs de JAX, fournit ces fonctionnalités.
Comme expliqué dans la documentation Orbax,
ces fonctionnalités sont activées par défaut
sur orbax.checkpoint.CheckpointManager
. La méthode save
appelé après chaque étape, il vérifie automatiquement
l'événement est imminent et, le cas échéant, enregistre un point de contrôle même si le numéro de l'étape
n'est pas un multiple de save_interval_steps
.
La documentation GitHub explique également comment arrêter l'entraînement après avoir enregistré un point de contrôle automatique, avec une modification dans le code utilisateur.