Trainingsfortschritt mit Autocheckpoint sichern

Wenn eine TPU-VM gewartet werden musste, wurde der Vorgang bisher sofort eingeleitet. Den Nutzern blieb keine Zeit, mit entsprechenden Aktionen den Fortschritt zu sichern, zum Beispiel einen Prüfpunkt zu speichern. Dies ist in Abbildung 1(a) dargestellt.

Diagramm, das die Auswirkungen der Host-Wartung mit und ohne automatisch erstellte Prüfpunkte zeigt

Abbildung 1 Darstellung der Funktion Autocheckpoint: (a) Ohne automatisch erstellte Prüfpunkte geht der Trainingsfortschritt seit dem letzten Prüfpunkt verloren, wenn eine Wartung ansteht. (b) Mit Autocheckpoint bleibt der Trainingsfortschritt ab dem letzten Prüfpunkt erhalten, wenn eine Wartung ansteht.

Mit Autocheckpoint (Abbildung 1(b)) können Sie den Trainingsfortschritt erhalten. Dazu konfigurieren Sie Ihren Code so, dass bei einem Wartungsereignis ein nicht geplanter Prüfpunkt gespeichert wird. Tritt ein Wartungsereignis ein, wird der Fortschritt seit dem letzten Prüfpunkt automatisch gespeichert. Die Funktion ist sowohl für einzelne Slices als auch für Multislice verfügbar.

Die Funktion Autocheckpoint funktioniert mit Frameworks, die SIGTERM-Signale erfassen und anschließend einen Prüfpunkt speichern können. Folgende Frameworks werden unterstützt:

Autocheckpoint verwenden

Die Funktion Autocheckpoint ist standardmäßig deaktiviert. Sie können die Funktion aktivieren, wenn Sie eine TPU erstellen oder eine Ressource aus der Warteschlange anfordern. Fügen Sie dazu beim Bereitstellen der TPU das Flag --autocheckpoint-enabled hinzu. Bei aktivierter Funktion führt Cloud TPU die folgenden Schritte aus, sobald eine Benachrichtigung über ein Wartungsereignis eingeht:

  1. SIGTERM-Signal erfassen, das mit dem TPU-Gerät an den Prozess gesendet wird
  2. Warten, bis der Prozess beendet wird oder 5 Minuten vergangen sind, je nachdem, was zuerst eintritt.
  3. Wartung für die betroffenen Slices durchführen

Die von der Funktion Autocheckpoint verwendete Infrastruktur ist vom ML-Framework unabhängig. Die Funktion wird von jedem ML-Framework unterstützt, das das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann.

Im Anwendungscode müssen Sie die vom ML-Framework bereitgestellten Prüfpunkt-Funktionen aktivieren. In Pax bedeutet dies beispielsweise, dass beim Starten des Trainings Befehlszeilen-Flags aktiviert werden müssen. Weitere Informationen finden Sie unter Kurzanleitung: Autocheckpoint mit Pax. Wenn ein SIGTERM-Signal empfangen wird, speichern die Frameworks im Hintergrund einen nicht geplanten Prüfpunkt, und die betroffene TPU-VM wird gewartet, sobald die TPU nicht mehr verwendet wird.

Kurzanleitung: Autocheckpoint mit MaxText

MaxText ist ein leistungsstarkes, beliebig skalierbares, Open-Source-LLM, das in reinem Python/JAX für Cloud TPUs geschrieben wurde und ausreichend Tests unterzogen wurde. MaxText enthält alle notwendigen Einstellungen, um die Funktion Autocheckpoint zu nutzen.

In der MaxText-Datei README werden zwei Möglichkeiten beschrieben, MaxText im großen Maßstab auszuführen:

Wenn Sie multihost_runner.py verwenden, aktivieren Sie Autocheckpoint, indem Sie beim Bereitstellen der Ressource aus der Warteschlange das Flag autocheckpoint-enabled setzen.

Wenn Sie multihost_job.py verwenden, aktivieren Sie Autocheckpoint, indem Sie beim Starten des Jobs das Befehlszeilenflag ENABLE_AUTOCHECKPOINT=true angeben.

Kurzanleitung: Autocheckpoint mit Pax auf einem einzelnen Slice

In diesem Abschnitt finden Sie ein Beispiel dafür, wie Sie Autocheckpoint mit Pax auf einem einzelnen Slice einrichten und verwenden. Bei entsprechender Einrichtung:

  • Bei einem Wartungsereignis wird ein Prüfpunkt gespeichert.
  • Cloud TPU führt nach dem Speichern des Prüfpunkts die Wartung der betroffenen TPU-VM(s) durch.
  • Nach abgeschlossener Wartung können Sie die TPU-VM wie gewohnt verwenden.
  1. Verwenden Sie das Flag autocheckpoint-enabled, wenn Sie die TPU-VM erstellen oder eine Ressource aus der Warteschlange anfordern.

    Beispiel:

    1. Legen Sie Umgebungsvariablen fest:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      Beschreibungen von Umgebungsvariablen

      Variable Beschreibung
      PROJECT_ID Ihre Projekt-ID in Google Cloud . Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt.
      TPU_NAME Der Name der TPU.
      ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
      ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
      RUNTIME_VERSION Die Softwareversion der Cloud TPU.

    2. Legen Sie Ihre Projekt-ID und Zone in Ihrer aktiven Konfiguration fest:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. Erstellen Sie eine TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. Stellen Sie eine SSH-Verbindung zur TPU her:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. Installieren Sie Pax auf einem einzelnen Slice

    Die Funktion Autocheckpoint ist in den Pax-Versionen 1.1.0 und höher verfügbar. Installieren Sie jax[tpu] und die aktuelle Version von paxml auf der TPU-VM:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Konfigurieren Sie das Modell LmCloudSpmd2B. Bevor Sie das Trainingsskript ausführen, ändern Sie ICI_MESH_SHAPE in [1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. Starten Sie das Training mit der entsprechenden Konfiguration.

    Im folgenden Beispiel wird gezeigt, wie Sie das LmCloudSpmd2B-Modell konfigurieren, um die von Autocheckpoint ausgelösten Prüfpunkte in einem Cloud Storage-Bucket zu speichern. Ersetzen Sie your-storage-bucket durch den Namen eines vorhandenen Buckets oder erstellen Sie einen neuen Bucket.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Beachten Sie die beiden Flags, die an den Befehl übergeben werden:

    • jax_fully_async_checkpoint: Wenn dieses Flag aktiviert ist, wird orbax.checkpoint.AsyncCheckpointer verwendet. Die Klasse AsyncCheckpointer speichert automatisch einen Prüfpunkt, wenn das Trainingsskript ein SIGTERM-Signal empfängt.
    • exit_after_ondemand_checkpoint: Wenn dieses Flag aktiviert ist, wird der TPU-Prozess beendet, nachdem der automatische Prüfpunkt erfolgreich gespeichert wurde. Dadurch wird die Wartung sofort ausgeführt. Wenn Sie dieses Flag nicht angeben, wird das Training nach dem Speichern des Prüfpunkts fortgesetzt und die Cloud TPU wartet bis zum Auftreten einer Zeitüberschreitung (5 Minuten), bevor die erforderliche Wartung durchgeführt wird.

Autocheckpoint mit Orbax

Die Funktion Autocheckpoint ist nicht auf MaxText oder Pax beschränkt. Jedes Framework, das das SIGTERM-Signal erfassen und einen Prüfpunktprozess starten kann, funktioniert mit der von Autocheckpoint bereitgestellten Infrastruktur. Orbax, ein Namespace, der JAX-Nutzern allgemeine Utility-Bibliotheken bereitstellt, bietet diese Funktionen.

Wie in der Orbax-Dokumentation erläutert, sind diese Funktionen standardmäßig für Nutzer von orbax.checkpoint.CheckpointManager aktiviert. Die save-Methode, die nach jedem Schritt aufgerufen wird, prüft automatisch, ob ein Wartungsereignis bevorsteht. Wenn dies der Fall ist, wird ein Prüfpunkt gespeichert, auch wenn die Schrittnummer kein Vielfaches von save_interval_steps ist. In der GitHub-Dokumentation wird auch gezeigt, wie das Training nach dem Speichern eines Autocheckpoints beendet werden kann, indem der Nutzercode geändert wird.