Cloud TPU Autocheckpoint [öffentliche Vorschau]

Übersicht

Wenn eine TPU-VM in der Vergangenheit Wartung Der Vorgang wird sofort eingeleitet, ohne dass die Nutzenden Zeit haben, Fortschrittserhaltende Aktionen ausführen, z. B. einen Prüfpunkt speichern Dies ist wie in Abbildung 1(a) dargestellt.

automatischer Prüfpunkt

Abbildung 1 Abbildung der Funktion „Autocheckpoint“: (a) Ohne Autocheckpoint geht der Trainingsfortschritt seit dem letzten Checkpoint verloren, wenn ein Wartungsereignis bevorsteht. (b) Mit dem automatischen Prüfpunkt Trainingsfortschritt seit der letzten Checkpoint kann bei einem bevorstehenden Wartungsereignis beibehalten werden.

Mit dem automatischen Checkpoint (Abbildung 1(b)) können Sie den Trainingsfortschritt beibehalten, indem Sie Ihren Code so konfigurieren, dass ein nicht geplanter Checkpoint gespeichert wird, wenn ein Wartungsereignis auftritt. Wenn ein Wartungsereignis auftritt, wird der Fortschritt seit dem letzten Checkpoint automatisch gespeichert. Das Feature kann auf beide Segmente angewendet werden. und Multislice.

Die automatische Prüfpunktfunktion arbeitet mit Frameworks, SIGTERM und speichern Sie anschließend einen Prüfpunkt. Folgende Frameworks werden unterstützt: MaxText, Pax, und JAX mit Orbax Unterstützung für weitere Frameworks wird bekannt gegeben, sobald sie verfügbar sind.

Derzeit können nur TPUs (v2–v4 und v5e), die über die Cloud TPU API erstellt wurden, diese Funktion nutzen. Unterstützung für TPUs in GKE ist ankündigen, sobald sie verfügbar sind.

Automatischen Prüfpunkt verwenden

Die automatische Prüfpunktfunktion ist standardmäßig deaktiviert. Wenn Sie eine TPU oder eine Ressource in der Warteschlange können Sie sie aktivieren, indem Sie bei der Bereitstellung das Flag --autocheckpoint-enabled hinzufügen. die TPU. Wenn das Feature aktiviert ist, führt die folgenden Schritte aus, sobald eine Benachrichtigung Wartungsereignis:

  1. SIGTERM erfassen, das über das TPU-Gerät an den Prozess gesendet wird,
  2. Wartet, bis der Prozess beendet ist oder 5 Minuten verstrichen sind, je nachdem, was zuerst eintritt, und führt Wartungsarbeiten an den betroffenen Slices durch.

Die von Autocheckpoint verwendete Infrastruktur ist unabhängig vom ML-Framework. Jedes ML-Framework kann Autocheckpoint unterstützen, sofern es das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann.

Im Anwendungscode müssen Sie die vom ML-Framework bereitgestellten Funktionen für automatische Checkpoints aktivieren. In Pax beispielsweise Das bedeutet, dass beim Starten der Training (siehe Autocheckpoint-Kurzanleitung mit Pax). Hinter den Kulissen speichern die Frameworks Nicht geplanter Prüfpunkt beim Empfang eines SIGTERM und die betroffene TPU-VM gewartet wird, wenn die TPU nicht mehr verwendet wird.

Kurzanleitung: Autocheckpoint mit MaxText

MaxText ist ein „hochleistungsfähiges, beliebig skalierbares, Open-Source-LLM, das in reiner Python/JAX geschrieben wurde und auf Cloud TPUs ausgerichtet ist“. MaxText enthält alle notwendigen Einstellungen zur Verwendung des automatischen Prüfpunkts .

In der MaxText-Readme werden zwei Möglichkeiten zum Ausführen von MaxText im großen Maßstab beschrieben:

Bei Verwendung von multihost_runner.py ist die einzige Änderung erforderlich das Flag autocheckpoint-enabled bei der Bereitstellung in der Warteschlange. Wenn Sie multihost_job.py verwenden, müssen Sie beim Starten des Jobs nur das Befehlszeilenflag ENABLE_AUTOCHECKPOINT=true angeben.

Kurzanleitung: Autocheckpoint mit Pax für einzelne Segmente

In diesem Abschnitt findest du ein Beispiel für die Einrichtung und Verwendung von Autocheckpoint mit Pax auf einem einzelnen Slab. Mit der entsprechenden Einrichtung:

  • Wenn ein Wartungsereignis auftritt, wird ein Checkpoint gespeichert.
  • Nach dem Speichern des Checkpoints führt Cloud TPU Wartungsarbeiten an den betroffenen TPU-VMs durch.
  • Sobald die Wartung von Cloud TPU abgeschlossen ist, können Sie die TPU-VM wie gewohnt verwenden.
  1. Verwenden Sie das Flag autocheckpoint-enabled beim Erstellen der TPU-VM oder Ressource in der Warteschlange.

    Beispiel:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
  2. Pax in einem einzelnen Slice installieren

    Die Funktion „Autocheckpoint“ funktioniert mit Pax-Versionen ab 1.1.0. Installieren Sie auf den TPU-VMs jax[tpu] und die neueste Version von paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Training mit der entsprechenden Konfiguration starten

    Das folgende Beispiel zeigt, wie LmCloudSpmd2B konfiguriert wird Modell zum Speichern der vom automatischen Prüfpunkt ausgelösten Prüfpunkte in einem Google Cloud Storage-Bucket:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Beachten Sie die beiden Flags, die an den Befehl übergeben werden:

    • jax_fully_async_checkpoint: Wenn dieses Flag aktiviert ist, wird orbax.checkpoint.AsyncCheckpointer verwendet. Der AsyncCheckpointer-Kurs speichert automatisch einen wenn das Trainingsskript ein SIGTERM-Signal empfängt.
    • exit_after_ondemand_checkpoint: Wenn dieses Flag aktiviert ist, werden die TPU-Prozesse beendet, nachdem der Autocheckpoint erfolgreich gespeichert wurde. Dadurch wird die Wartung sofort ausgeführt. Wenn Sie diese Flag gesetzt, wird das Training fortgesetzt, nachdem der Prüfpunkt gespeichert wurde. und Cloud TPU wartet auf ein Zeitlimit (5 Minuten) bevor Sie die erforderliche Wartung durchführen.

Kurzanleitung: Autocheckpoint mit Pax auf Multislice

Der automatische Prüfpunkt funktioniert nicht nur für einzelne Segmente, für Multislice. In diesem Abschnitt wird beschrieben, wie Sie automatische Checkpoints mit Multislice verwenden.

  1. Geben Sie „Autocheckpoint“ an, wenn Sie Ressourcen in der Warteschlange erstellen.

    Eine Multislice-Umgebung kann nur über eine Warteschlange bereitgestellt werden Ressourcenanfrage. Ähnlich wie beim Fall mit einer einzelnen Scheibe verwenden Sie das Flag autocheckpoint-enabled im Aufruf, um eine Ressourcenwarteschlange zu erstellen.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled

    Weitere Informationen finden Sie im Nutzerhandbuch für mehrere Segmente. . Sobald die Ressource in der Warteschlange Anfrage erstellt Führen Sie im Status ACTIVE die nächsten Schritte aus, um Pax mit Automatischer Prüfpunkt.

  2. Pax auf allen VMs in der Multislice-Umgebung installieren

    Installieren Sie auf den TPU-VMs jax[tpu] und die aktuelle Version von paxml. auf allen TPU-VMs in Ihrer Multislice-Umgebung:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Training mit der entsprechenden Konfiguration starten

    In diesem Beispiel wird gezeigt, wie Sie das Modell LmCloudSpmd2B für automatische Checkpoints konfigurieren, wenn Sie in einer Multislice-Umgebung trainieren. Vorher des Trainingsskripts ausführen, setzen Sie DCN_MESH_SHAPE auf [2, 1, 1], wie in den folgenden Code:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]

    Beim Starten des Trainings sind zusätzlich zu den Befehlszeilenoptionen, die im Fall mit einer einzelnen Scheibe beschrieben wurden, noch drei weitere erforderlich:

    • num_hosts: die Gesamtzahl der Hosts. In diesem Fall ist das 2.
    • host_index: der Index des Hosts, der das Training startet. Abweichend von 0 bis N-1, wobei N die Gesamtzahl der Hosts ist.
    • server_addr: die IP-Adresse von Worker 0 von Knoten 0 mit einem nicht verwendeten Port (z. B. 8476). Verwenden Sie dazu hostname -i auf Worker 0 von Knoten 0.

Automatischer Checkpoint mit Orbax

Die Funktion „Autocheckpoint“ ist nicht auf MaxText oder Pax beschränkt. Jedes Framework, das das SIGTERM-Signal erfassen und einen Checkpoint-Prozess initiieren kann, funktioniert mit der von Autocheckpoint bereitgestellten Infrastruktur. Orbax, ein Namespace zur Bereitstellung von gängige Dienstprogrammbibliotheken für JAX-Nutzer.

Wie in der Orbax-Dokumentation erläutert, Diese Funktionen sind für Nutzer von orbax.checkpoint.CheckpointManager. Die Methode save der nach jedem Schritt aufgerufen wird, wird automatisch geprüft, bevor ein Ereignis ansteht. Falls ja, wird ein Prüfpunkt gespeichert, selbst wenn die Schrittzahl ist kein Vielfaches von save_interval_steps. In der GitHub-Dokumentation wird auch veranschaulicht, wie das Training nach dem Speichern eines automatischen Checkpoints beendet werden kann, indem der Nutzercode geändert wird.