Point de contrôle automatique Cloud TPU [version Preview publique]

Présentation

Historiquement, lorsqu'une VM TPU nécessite une maintenance, la procédure est lancée immédiatement, sans laisser le temps aux utilisateurs d'effectuer des actions préservant la progression, telles que l'enregistrement d'un point de contrôle. Ce processus est illustré dans la figure 1(a).

point de contrôle automatique

Fig. 1. Illustration de la fonctionnalité de point de contrôle automatique : (a) Sans le point de contrôle automatique, la progression de l'entraînement depuis le dernier point de contrôle est perdue en cas d'événement de maintenance à venir. (b) Avec le point de contrôle automatique, la progression de l'entraînement depuis le dernier point de contrôle peut être préservée en cas d'événement de maintenance à venir.

Vous pouvez utiliser un point de contrôle automatique (Figure 1(b)) pour préserver la progression de l'entraînement. Pour ce faire, configurez votre code pour enregistrer un point de contrôle non planifié lorsqu'un événement de maintenance se produit. Lorsqu'un événement de maintenance se produit, la progression depuis le dernier point de contrôle est automatiquement enregistrée. Cette fonctionnalité s'applique aux tranches simples et multiples.

La fonctionnalité Autocheckpoint fonctionne avec des frameworks qui peuvent capturer le signal SIGTERM, puis enregistrer un point de contrôle. Les frameworks compatibles incluent MaxText, Pax et JAX avec Orbax. D'autres frameworks seront pris en charge dès qu'ils seront disponibles.

Seuls les TPU (v2-v4 et v5e) créés via l'API Cloud TPU peuvent utiliser cette fonctionnalité pour le moment. La compatibilité des TPU dans GKE sera annoncée lorsqu'elle sera disponible.

Utiliser Autocheckpoint

La fonctionnalité de point de contrôle automatique est désactivée par défaut. Lorsque vous créez un TPU ou une ressource en file d'attente, vous pouvez l'activer en ajoutant l'indicateur --autocheckpoint-enabled lors du provisionnement du TPU. Lorsque cette fonctionnalité est activée, Cloud TPU effectue les étapes suivantes après avoir reçu la notification d'un événement de maintenance:

  1. Capturez le SIGTERM envoyé au processus à l'aide du périphérique TPU,
  2. Attend que le processus se ferme ou que cinq minutes se soient écoulées, selon la première échéance atteinte, et effectue une maintenance sur les tranches concernées.

Notez que l'infrastructure utilisée par Autocheckpoint est indépendante du framework de ML. Tout framework de ML peut prendre en charge la fonctionnalité Autocheckpoint à condition qu'elle puisse capturer le signal SIGTERM et lancer un processus de création de points de contrôle.

Dans le code de l'application, vous devez activer les fonctionnalités Autocheckpoint fournies par le framework de ML. Dans Pax, par exemple, cela implique d'activer les indicateurs de ligne de commande lors du lancement de l'entraînement (consultez le guide de démarrage rapide sur les points de contrôle automatiques avec Pax). En arrière-plan, les frameworks enregistrent un point de contrôle non programmé lors de la réception d'un signal SIGTERM et la VM TPU concernée fait l'objet d'une maintenance lorsque le TPU n'est plus utilisé.

Guide de démarrage rapide: point de contrôle automatique avec MaxText

MaxText est un "LLM à hautes performances, arbitrairement évolutif, Open Source et bien testé, écrit en langage Python/JAX pur ciblant des Cloud TPU". MaxText contient toute la configuration nécessaire pour utiliser la fonctionnalité Autocheckpoint.

Le fichier README de MaxText décrit deux façons d'exécuter MaxText à grande échelle:

Avec multihost_runner.py, la seule modification requise consiste à définir l'option autocheckpoint-enabled lors du provisionnement de la ressource en file d'attente. Lorsque vous utilisez multihost_job.py, la seule modification requise consiste à spécifier l'indicateur de ligne de commande ENABLE_AUTOCHECKPOINT=true lors du lancement de la tâche.

Guide de démarrage rapide: point de contrôle automatique avec Pax sur des tranches uniques

Dans cette section, nous fournissons un exemple de configuration et d'utilisation du point de contrôle automatique avec Pax sur une seule tranche. Avec la configuration appropriée:

  • Un point de contrôle est enregistré lorsqu'un événement de maintenance se produit.
  • Une fois le point de contrôle enregistré, Cloud TPU effectue la maintenance sur la ou les VM TPU concernées.
  • Une fois la maintenance terminée, vous pouvez utiliser la VM TPU comme d'habitude.
  1. Utilisez l'indicateur autocheckpoint-enabled lors de la création de la VM TPU ou de la ressource en file d'attente.

    Exemple :

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Installer Pax sur une seule tranche

    La fonctionnalité Autocheckpoint fonctionne sur les versions Pax 1.1.0 ou ultérieures. Sur les VM TPU, installez jax[tpu] et la dernière version de paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Lancer l'entraînement avec la configuration appropriée

    L'exemple suivant montre comment configurer le modèle LmCloudSpmd2B pour enregistrer les points de contrôle déclenchés par Autocheckpoint dans un bucket Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Notez les deux indicateurs transmis à la commande:

    • jax_fully_async_checkpoint : lorsque cet indicateur est activé, orbax.checkpoint.AsyncCheckpointer est utilisé. La classe AsyncCheckpointer enregistre automatiquement un point de contrôle lorsque le script d'entraînement reçoit un signal SIGTERM.
    • exit_after_ondemand_checkpoint : lorsque cet indicateur est activé, les processus TPU se ferment une fois le point de contrôle automatique enregistré, ce qui déclenche la maintenance immédiatement. Si vous n'utilisez pas cet indicateur, l'entraînement se poursuit après l'enregistrement du point de contrôle, et Cloud TPU attend un délai avant expiration (cinq minutes) avant d'effectuer la maintenance requise.

Guide de démarrage rapide: point de contrôle automatique avec Pax sur multitranche

Le point de contrôle automatique fonctionne non seulement pour les tranches uniques, mais également pour les tranches multiples. Cette section détaille les étapes nécessaires à l'utilisation du point de contrôle automatique avec la multitranche.

  1. Spécifiez le point de contrôle automatique lors de la création des ressources en file d'attente.

    Un environnement à plusieurs tranches ne peut être provisionné que via une requête de ressource en file d'attente. Comme pour le cas d'une tranche unique, utilisez l'indicateur autocheckpoint-enabled dans l'appel pour créer une ressource en file d'attente.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Reportez-vous au guide de l'utilisateur pour les secteurs multiples pour en savoir plus sur toutes les options disponibles. Une fois la demande de ressources en file d'attente créée et à l'état ACTIVE, suivez les étapes suivantes pour exécuter Pax avec Autocheckpoint.

  2. Installer Pax sur toutes les VM de l'environnement multitranche

    Sur les VM TPU, installez jax[tpu] et la dernière version de paxml sur toutes les VM TPU de votre environnement multitranche:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Lancer l'entraînement avec la configuration appropriée

    Cet exemple montre comment configurer le modèle LmCloudSpmd2B pour Autocheckpoint lors de l'entraînement dans un environnement à plusieurs tranches. Avant d'exécuter le script d'entraînement, définissez DCN_MESH_SHAPE sur [2, 1, 1] comme indiqué dans le code suivant:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Lors du lancement de l'entraînement, en plus des indicateurs de ligne de commande abordés dans le cas d'une tranche unique, trois autres sont requis:

    • num_hosts: nombre total d'hôtes. Dans ce cas, c’est 2.
    • host_index: index de l'hôte qui lance l'entraînement. Il varie de 0 à N-1, où N correspond au nombre total d'hôtes.
    • server_addr: adresse IP du nœud de calcul 0 du nœud 0, avec un port inutilisé (par exemple, 8476). Pour trouver ces informations, utilisez hostname -i sur le nœud de calcul 0 du nœud 0.

Point de contrôle automatique avec Orbax

La fonctionnalité Autocheckpoint n'est pas limitée à MaxText ou Pax. Tout framework capable de capturer le signal SIGTERM et de lancer un processus de création de points de contrôle fonctionne avec l'infrastructure fournie par Autocheckpoint. Orbax, un espace de noms qui fournit des bibliothèques d'utilitaires courantes pour les utilisateurs JAX, offre ces fonctionnalités.

Comme expliqué dans la documentation Orbax, ces fonctionnalités sont activées par défaut pour les utilisateurs de orbax.checkpoint.CheckpointManager. La méthode save appelée après chaque étape vérifie automatiquement si un événement de maintenance est imminent et, le cas échéant, enregistre un point de contrôle même si le numéro d'étape n'est pas un multiple de save_interval_steps. La documentation GitHub montre également comment fermer l'entraînement après l'enregistrement d'un point de contrôle automatique, avec une modification du code utilisateur.