Cloud TPU Autocheckpoint [öffentliche Vorschau]

Überblick

Wenn für eine TPU-VM eine Wartung erforderlich ist, wird das Verfahren in der Vergangenheit sofort initiiert, ohne dass Nutzer Zeit haben, fortschrittserhaltende Aktionen wie das Speichern eines Prüfpunkts auszuführen. Dies ist in Abbildung 1(a) dargestellt.

Autocheckpoint

Abb. 1 Darstellung der Autocheckpoint-Funktion: (a) Ohne Autocheckpoint geht der Trainingsfortschritt ab dem letzten Prüfpunkt bei einem anstehenden Wartungsereignis verloren. (b) Mit Autocheckpoint kann der Trainingsfortschritt seit dem letzten Checkpoint beibehalten werden, wenn ein Wartungsereignis ansteht.

Sie können Autocheckpoint verwenden (Abbildung 1(b)), um den Trainingsfortschritt beizubehalten, indem Sie Ihren Code so konfigurieren, dass bei einem Wartungsereignis ein nicht geplanter Checkpoint gespeichert wird. Wenn ein Wartungsereignis auftritt, wird der Fortschritt seit dem letzten Checkpoint automatisch gespeichert. Das Feature kann sowohl auf einzelne Segmente als auch auf Mehrfachsegmente angewendet werden.

Die Autocheckpoint-Funktion funktioniert mit Frameworks, die SIGTERM erfassen und anschließend einen Prüfpunkt speichern können. Zu den unterstützten Frameworks gehören MaxText, Pax und JAX mit Orbax. Die Unterstützung weiterer Frameworks wird angekündigt, sobald sie verfügbar sind.

Derzeit können nur TPUs (v2-v4 und v5e), die mit der Cloud TPU API erstellt wurden, dieses Feature verwenden. Die Unterstützung für TPUs in GKE wird angekündigt, sobald sie verfügbar ist.

Autocheckpoint verwenden

Die Autocheckpoint-Funktion ist standardmäßig deaktiviert. Wenn Sie eine TPU oder eine Ressource in der Warteschlange erstellen, können Sie diese aktivieren, indem Sie bei der Bereitstellung der TPU das Flag --autocheckpoint-enabled hinzufügen. Wenn das Feature aktiviert ist, führt Cloud TPU die folgenden Schritte aus, sobald eine Benachrichtigung über ein Wartungsereignis eingeht:

  1. Erfassen Sie SIGTERM, das an den Prozess gesendet wurde, mit dem TPU-Gerät.
  2. Es wird gewartet, bis der Prozess beendet ist oder 5 Minuten vergangen sind, je nachdem, was zuerst eintritt, und dann werden die betroffenen Segmente gewartet.

Die von Autocheckpoint verwendete Infrastruktur ist unabhängig vom ML-Framework. Jedes ML-Framework kann Autocheckpoint unterstützen, sofern es das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann.

Im Anwendungscode müssen Sie die vom ML-Framework bereitgestellten Autocheckpoint-Funktionen aktivieren. In Pax bedeutet dies beispielsweise, dass beim Start des Trainings Befehlszeilen-Flags aktiviert werden (siehe Autocheckpoint-Kurzanleitung mit Pax). Im Hintergrund speichern die Frameworks einen nicht geplanten Checkpoint, wenn ein SIGTERM empfangen wird und die betroffene TPU-VM einer Wartung unterzogen wird, wenn die TPU nicht mehr verwendet wird.

Kurzanleitung: Autocheckpoint mit MaxText

MaxText ist ein leistungsfähiges, beliebig skalierbares, gut getestetes Open-Source-LLM, das in reinem Python/JAX geschrieben wurde und auf Cloud TPUs ausgerichtet ist. MaxText enthält alle erforderlichen Einstellungen zur Verwendung der Autocheckpoint-Funktion.

In der MaxText-Readme-Datei werden zwei Möglichkeiten zum Ausführen von MaxText im großen Maßstab beschrieben:

Wenn Sie multihost_runner.py verwenden, muss nur das Flag autocheckpoint-enabled beim Bereitstellen der Ressource in der Warteschlange festgelegt werden. Wenn Sie multihost_job.py verwenden, muss nur das Befehlszeilen-Flag ENABLE_AUTOCHECKPOINT=true beim Starten des Jobs angegeben werden.

Kurzanleitung: Autocheckpoint mit Pax für einzelne Segmente

In diesem Abschnitt wird anhand eines Beispiels erläutert, wie Sie Autocheckpoint mit Pax für ein einzelnes Segment einrichten und verwenden. Bei richtiger Einrichtung:

  • Ein Checkpoint wird gespeichert, wenn ein Wartungsereignis eintritt.
  • Cloud TPU führt eine Wartung der betroffenen TPU-VM(s) durch, nachdem der Prüfpunkt gespeichert wurde.
  • Wenn Cloud TPU die Wartung abgeschlossen hat, können Sie die TPU-VM wie gewohnt verwenden.
  1. Verwenden Sie das Flag autocheckpoint-enabled, wenn Sie die TPU-VM oder die Ressource in der Warteschlange erstellen.

    Beispiel:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Pax auf einem einzelnen Segment installieren

    Die Autocheckpoint-Funktion funktioniert auf Pax-Versionen ab 1.1.0. Installieren Sie jax[tpu] und die aktuelle Version von paxml auf den TPU-VMs:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Training mit der entsprechenden Konfiguration starten

    Das folgende Beispiel zeigt, wie Sie das Modell LmCloudSpmd2B so konfigurieren, dass Prüfpunkte, die von Autocheckpoint ausgelöst wurden, in einem Google Cloud Storage-Bucket gespeichert werden:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Beachten Sie die beiden Flags, die an den Befehl übergeben werden:

    • jax_fully_async_checkpoint: Wenn dieses Flag aktiviert ist, wird orbax.checkpoint.AsyncCheckpointer verwendet. Die Klasse AsyncCheckpointer speichert automatisch einen Checkpoint, wenn das Trainingsskript ein SIGTERM-Signal empfängt.
    • exit_after_ondemand_checkpoint: Wenn dieses Flag aktiviert ist, wird die TPU-Prozesse beendet, nachdem der Autocheckpoint erfolgreich gespeichert wurde. Dadurch wird die Wartung sofort ausgelöst. Wenn Sie dieses Flag nicht verwenden, wird das Training fortgesetzt, nachdem der Prüfpunkt gespeichert wurde. Cloud TPU wartet auf ein Zeitlimit (5 Minuten), bevor die erforderliche Wartung durchgeführt wird.

Kurzanleitung: Autocheckpoint mit Pax für Multislice

Autocheckpoint funktioniert nicht nur für einzelne Segmente, sondern auch für Multislice. In diesem Abschnitt werden die Schritte beschrieben, die zum Verwenden von Autocheckpoint mit Multislice erforderlich sind.

  1. Geben Sie beim Erstellen von Ressourcen in der Warteschlange einen Autocheckpoint an.

    Eine Umgebung mit mehreren Teilen kann nur über eine Ressourcenanfrage in der Warteschlange bereitgestellt werden. Ähnlich wie bei der Single-Slice-Methode verwenden Sie das Flag autocheckpoint-enabled im Aufruf, um eine Ressource in der Warteschlange zu erstellen.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Weitere Informationen zu allen verfügbaren Optionen finden Sie im Nutzerhandbuch für mehrere Segmente. Sobald die Ressourcenanfrage in der Warteschlange erstellt wurde und den Status ACTIVE hat, führen Sie die nächsten Schritte aus, um Pax mit Autocheckpoint auszuführen.

  2. Installieren Sie Pax auf allen VMs in der Multislice-Umgebung.

    Installieren Sie auf den TPU-VMs jax[tpu] und die aktuelle Version von paxml auf allen TPU-VMs in Ihrer Multislice-Umgebung:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Training mit der entsprechenden Konfiguration starten

    In diesem Beispiel wird gezeigt, wie das Modell LmCloudSpmd2B für Autocheckpoint beim Training in einer Multislice-Umgebung konfiguriert wird. Bevor Sie das Trainingsskript ausführen, legen Sie DCN_MESH_SHAPE auf [2, 1, 1] fest, wie im folgenden Code gezeigt:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Beim Starten des Trainings sind zusätzlich zu den Befehlszeilen-Flags, die im Einzel-Slice-Fall beschrieben wurden, drei weitere erforderlich:

    • num_hosts: die Gesamtzahl der Hosts In diesem Fall ist es 2.
    • host_index: Index des Hosts, der das Training startet. Er reicht von 0 bis N-1, wobei N die Gesamtzahl der Hosts ist.
    • server_addr: die IP-Adresse von Worker 0 von Knoten 0, mit einem nicht verwendeten Port (z. B. 8476). Verwenden Sie zum Ermitteln dieser Informationen hostname -i auf Worker 0 von Knoten 0.

Autocheckpoint mit Orbax

Die Autocheckpoint-Funktion ist nicht auf MaxText oder Pax beschränkt. Jedes Framework, das das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann, funktioniert mit der von Autocheckpoint bereitgestellten Infrastruktur. Orbax, ein Namespace, der allgemeine Dienstprogrammbibliotheken für JAX-Nutzer bereitstellt, bietet diese Funktionen.

Wie in der Orbax-Dokumentation erläutert, sind diese Funktionen für Nutzer von orbax.checkpoint.CheckpointManager standardmäßig aktiviert. Die Methode save, die nach jedem Schritt aufgerufen wird, prüft automatisch, ob ein Wartungsereignis ansteht. Wenn ja, wird ein Prüfpunkt gespeichert, auch wenn die Schrittzahl kein Vielfaches von save_interval_steps ist. Die GitHub-Dokumentation veranschaulicht auch, wie das Training nach dem Speichern eines Autocheckpoint mit einer Änderung im Nutzercode beendet wird.