Cloud TPU Autocheckpoint [Preview public]

Présentation

Historiquement, lorsqu'une VM TPU nécessite une maintenance, la procédure est lancée immédiatement, sans laisser aux utilisateurs le temps d'effectuer des actions de conservation de la progression, telles que l'enregistrement d'un point de contrôle. C'est ce que montre la figure 1(a).

autocheckpoint

Fig. 1. Illustration de la fonctionnalité de point de contrôle automatique : (a) Sans point de contrôle automatique, la progression de l'entraînement à partir du dernier point de contrôle est perdue en cas d'événement de maintenance à venir. (b) Avec le point de contrôle automatique, la progression de l'entraînement depuis le dernier point de contrôle peut être préservée en cas d'événement de maintenance à venir.

Vous pouvez utiliser le point de contrôle automatique (figure 1b) pour préserver la progression de l'entraînement en configurant votre code pour enregistrer un point de contrôle non planifié lorsqu'un événement de maintenance se produit. Lorsqu'un événement de maintenance se produit, la progression depuis le dernier point de contrôle est automatiquement enregistrée. Cette fonctionnalité fonctionne à la fois sur les tranches uniques et sur Multislice.

La fonctionnalité Autocheckpoint fonctionne avec les frameworks pouvant capturer SIGTERM et enregistrer ensuite un point de contrôle. Les frameworks compatibles incluent MaxText, Pax et JAX avec Orbax. La prise en charge de frameworks supplémentaires sera annoncée à mesure qu'ils seront disponibles.

Pour le moment, seuls les TPU (versions v2 à v4 et v5e) créés via l'API Cloud TPU peuvent utiliser cette fonctionnalité. La prise en charge des TPU dans GKE sera annoncée lorsqu'elle sera disponible.

Utiliser le point de contrôle automatique

La fonctionnalité de point de contrôle automatique est désactivée par défaut. Lorsque vous créez un TPU ou une ressource mise en file d'attente, vous pouvez l'activer en ajoutant l'indicateur --autocheckpoint-enabled lors du provisionnement du TPU. Lorsque cette fonctionnalité est activée, Cloud TPU effectue les étapes suivantes lorsqu'il reçoit une notification d'événement de maintenance:

  1. Capturez le signal SIGTERM envoyé au processus à l'aide de l'appareil TPU.
  2. Attend que le processus se termine ou que cinq minutes se soient écoulées, selon la première éventualité, puis effectue la maintenance des tranches concernées.

Notez que l'infrastructure utilisée par Autocheckpoint est indépendante du framework de ML. Tout framework de ML peut prendre en charge Autocheckpoint à condition qu'il puisse capturer le signal SIGTERM et lancer un processus de création de points de contrôle.

Dans le code de l'application, vous devez activer les fonctionnalités de point de contrôle automatique fournies par le framework de ML. Dans Pax, par exemple, cela signifie activer les indicateurs de ligne de commande lors du lancement de l'entraînement (voir le démarrage rapide de l'autocheckpoint avec Pax). En coulisses, les frameworks enregistrent un point de contrôle non planifié lorsqu'un SIGTERM est reçu et que la VM TPU concernée est soumise à une maintenance lorsque le TPU n'est plus utilisé.

Guide de démarrage rapide: Point de contrôle automatique avec MaxText

MaxText est un LLM Open Source hautes performances, évolutif de manière arbitraire, bien testé et écrit en Python/JAX pur, qui cible les Cloud TPU. MaxText contient toute la configuration nécessaire pour utiliser la fonctionnalité de point de contrôle automatique.

Le fichier README de MaxText décrit deux façons d'exécuter MaxText à grande échelle:

Lorsque vous utilisez multihost_runner.py, la seule modification requise consiste à définir l'indicateur autocheckpoint-enabled lors du provisionnement de la ressource mise en file d'attente. Lorsque vous utilisez multihost_job.py, la seule modification requise consiste à spécifier l'indicateur de ligne de commande ENABLE_AUTOCHECKPOINT=true lors du lancement de la tâche.

Démarrage rapide: point de contrôle automatique avec Pax sur des tranches uniques

Dans cette section, nous vous expliquons comment configurer et utiliser Autocheckpoint avec Pax sur une seule tranche. Avec la configuration appropriée:

  • Un point de contrôle est enregistré lorsqu'un événement de maintenance se produit.
  • Cloud TPU effectuera la maintenance des VM TPU concernées une fois le point de contrôle enregistré.
  • Une fois la maintenance terminée, vous pouvez utiliser la VM TPU comme d'habitude.
  1. Utilisez l'option autocheckpoint-enabled lorsque vous créez la VM TPU ou la ressource mise en file d'attente.

    Exemple :

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
  2. Installer Pax sur une seule tranche

    La fonctionnalité de point de contrôle automatique fonctionne avec les versions Pax 1.1.0 et ultérieures. Sur les VM TPU, installez jax[tpu] et la dernière version de paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Lancer l'entraînement avec la configuration appropriée

    L'exemple suivant montre comment configurer le modèle LmCloudSpmd2B pour enregistrer les points de contrôle déclenchés par Autocheckpoint dans un bucket Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Notez les deux options transmises à la commande:

    • jax_fully_async_checkpoint : lorsque cette option est activée, orbax.checkpoint.AsyncCheckpointer est utilisé. La classe AsyncCheckpointer enregistre automatiquement un point de contrôle lorsque le script d'entraînement reçoit un signal SIGTERM.
    • exit_after_ondemand_checkpoint : lorsque cet indicateur est activé, les processus TPU se terminent une fois le point de contrôle automatique enregistré, ce qui déclenche immédiatement la maintenance. Si vous n'utilisez pas cet indicateur, l'entraînement se poursuit après l'enregistrement du point de contrôle, et Cloud TPU attend un délai avant d'effectuer la maintenance requise (cinq minutes).

Démarrage rapide: point de contrôle automatique avec Pax sur Multislice

Le point de contrôle automatique fonctionne non seulement pour les tranches uniques, mais aussi pour les tranches multiples. Cette section décrit les étapes à suivre pour utiliser Autocheckpoint avec Multislice.

  1. Spécifiez le point de contrôle automatique lors de la création de ressources en file d'attente.

    Un environnement multislice ne peut être provisionné que via une requête de ressources mise en file d'attente. Comme pour le cas d'une seule tranche, utilisez l'indicateur autocheckpoint-enabled dans l'appel pour créer une ressource mise en file d'attente.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled

    Consultez le guide de l'utilisateur Multislice pour en savoir plus sur toutes les options disponibles. Une fois la requête de ressources mise en file d'attente créée et à l'état ACTIVE, suivez les étapes suivantes pour exécuter Pax avec Autocheckpoint.

  2. Installez Pax sur toutes les VM de l'environnement Multislice.

    Sur les VM TPU, installez jax[tpu] et la dernière version de paxml sur toutes les VM TPU de votre environnement multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Lancer l'entraînement avec la configuration appropriée

    Cet exemple montre comment configurer le modèle LmCloudSpmd2B pour le point de contrôle automatique lors de l'entraînement dans un environnement multicouche. Avant d'exécuter le script d'entraînement, définissez DCN_MESH_SHAPE sur [2, 1, 1], comme indiqué dans le code suivant:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]

    Lors du lancement de l'entraînement, en plus des indicateurs de ligne de commande abordés dans le cas d'une seule tranche, trois autres sont requis:

    • num_hosts: nombre total d'hôtes. Dans ce cas, il s'agit de 2.
    • host_index: indice de l'hôte qui lance l'entraînement. Il varie de 0 à N-1, où N est le nombre total d'hôtes.
    • server_addr: adresse IP du nœud de travail 0 du nœud 0, avec un port inutilisé (par exemple, 8476). Pour obtenir ces informations, utilisez hostname -i sur le nœud de calcul 0.

Point de contrôle automatique avec Orbax

La fonctionnalité de contrôle automatique n'est pas limitée à MaxText ou Pax. Tout framework capable de capturer le signal SIGTERM et d'initier un processus de point de contrôle fonctionne avec l'infrastructure fournie par Autocheckpoint. Orbax, un espace de noms qui fournit des bibliothèques d'utilitaires courantes pour les utilisateurs de JAX, fournit ces fonctionnalités.

Comme expliqué dans la documentation Orbax, ces fonctionnalités sont activées par défaut pour les utilisateurs de orbax.checkpoint.CheckpointManager. La méthode save appelée après chaque étape vérifie automatiquement si un événement de maintenance est imminent. Si c'est le cas, elle enregistre un point de contrôle, même si le numéro d'étape n'est pas un multiple de save_interval_steps. La documentation GitHub explique également comment arrêter l'entraînement après avoir enregistré un point de contrôle automatique, avec une modification dans le code utilisateur.