Trainingsfortschritt mit Autocheckpoint sichern
Wenn eine TPU-VM gewartet werden musste, wurde der Vorgang bisher sofort eingeleitet. Den Nutzern blieb keine Zeit, mit entsprechenden Aktionen den Fortschritt zu sichern, zum Beispiel einen Prüfpunkt zu speichern. Dies ist in Abbildung 1(a) dargestellt.
Abbildung 1 Darstellung der Funktion Autocheckpoint: (a) Ohne automatisch erstellte Prüfpunkte geht der Trainingsfortschritt seit dem letzten Prüfpunkt verloren, wenn eine Wartung ansteht. (b) Mit Autocheckpoint bleibt der Trainingsfortschritt ab dem letzten Prüfpunkt erhalten, wenn eine Wartung ansteht.
Mit Autocheckpoint (Abbildung 1(b)) können Sie den Trainingsfortschritt erhalten. Dazu konfigurieren Sie Ihren Code so, dass bei einem Wartungsereignis ein nicht geplanter Prüfpunkt gespeichert wird. Tritt ein Wartungsereignis ein, wird der Fortschritt seit dem letzten Prüfpunkt automatisch gespeichert. Die Funktion ist sowohl für einzelne Slices als auch für Multislice verfügbar.
Die Funktion Autocheckpoint funktioniert mit Frameworks, die SIGTERM-Signale erfassen und anschließend einen Prüfpunkt speichern können. Folgende Frameworks werden unterstützt:
Autocheckpoint verwenden
Die Funktion Autocheckpoint ist standardmäßig deaktiviert. Sie können die Funktion aktivieren, wenn Sie eine TPU erstellen oder eine Ressource aus der Warteschlange anfordern. Fügen Sie dazu beim Bereitstellen der TPU das Flag --autocheckpoint-enabled
hinzu.
Bei aktivierter Funktion führt Cloud TPU die folgenden Schritte aus, sobald eine Benachrichtigung über ein Wartungsereignis eingeht:
- SIGTERM-Signal erfassen, das mit dem TPU-Gerät an den Prozess gesendet wird
- Warten, bis der Prozess beendet wird oder 5 Minuten vergangen sind, je nachdem, was zuerst eintritt.
- Wartung für die betroffenen Slices durchführen
Die von der Funktion Autocheckpoint verwendete Infrastruktur ist vom ML-Framework unabhängig. Die Funktion wird von jedem ML-Framework unterstützt, das das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann.
Im Anwendungscode müssen Sie die vom ML-Framework bereitgestellten Prüfpunkt-Funktionen aktivieren. In Pax bedeutet dies beispielsweise, dass beim Starten des Trainings Befehlszeilen-Flags aktiviert werden müssen. Weitere Informationen finden Sie unter Kurzanleitung: Autocheckpoint mit Pax. Wenn ein SIGTERM-Signal empfangen wird, speichern die Frameworks im Hintergrund einen nicht geplanten Prüfpunkt, und die betroffene TPU-VM wird gewartet, sobald die TPU nicht mehr verwendet wird.
Kurzanleitung: Autocheckpoint mit MaxText
MaxText ist ein leistungsstarkes, beliebig skalierbares, Open-Source-LLM, das in reinem Python/JAX für Cloud TPUs geschrieben wurde und ausreichend Tests unterzogen wurde. MaxText enthält alle notwendigen Einstellungen, um die Funktion Autocheckpoint zu nutzen.
In der MaxText-Datei README
werden zwei Möglichkeiten beschrieben, MaxText im großen Maßstab auszuführen:
- mit
multihost_runner.py
(für Tests empfohlen) - mit
multihost_job.py
(für die Produktion empfohlen)
Wenn Sie multihost_runner.py
verwenden, aktivieren Sie Autocheckpoint, indem Sie beim Bereitstellen der Ressource aus der Warteschlange das Flag autocheckpoint-enabled
setzen.
Wenn Sie multihost_job.py
verwenden, aktivieren Sie Autocheckpoint, indem Sie beim Starten des Jobs das Befehlszeilenflag ENABLE_AUTOCHECKPOINT=true
angeben.
Kurzanleitung: Autocheckpoint mit Pax auf einem einzelnen Slice
In diesem Abschnitt finden Sie ein Beispiel dafür, wie Sie Autocheckpoint mit Pax auf einem einzelnen Slice einrichten und verwenden. Bei entsprechender Einrichtung:
- Bei einem Wartungsereignis wird ein Prüfpunkt gespeichert.
- Cloud TPU führt nach dem Speichern des Prüfpunkts die Wartung der betroffenen TPU-VM(s) durch.
- Nach abgeschlossener Wartung können Sie die TPU-VM wie gewohnt verwenden.
Verwenden Sie das Flag
autocheckpoint-enabled
, wenn Sie die TPU-VM erstellen oder eine Ressource aus der Warteschlange anfordern.Beispiel:
Legen Sie Umgebungsvariablen fest:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Projekt-ID in Google Cloud . Verwenden Sie ein vorhandenes oder erstellen Sie ein neues Projekt. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. Legen Sie Ihre Projekt-ID und Zone in Ihrer aktiven Konfiguration fest:
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
Erstellen Sie eine TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
Stellen Sie eine SSH-Verbindung zur TPU her:
gcloud compute tpus tpu-vm ssh $TPU_NAME
Installieren Sie Pax auf einem einzelnen Slice
Die Funktion Autocheckpoint ist in den Pax-Versionen 1.1.0 und höher verfügbar. Installieren Sie
jax[tpu]
und die aktuelle Version vonpaxml
auf der TPU-VM:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Konfigurieren Sie das Modell
LmCloudSpmd2B
. Bevor Sie das Trainingsskript ausführen, ändern SieICI_MESH_SHAPE
in[1, 8, 1]
:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
Starten Sie das Training mit der entsprechenden Konfiguration.
Im folgenden Beispiel wird gezeigt, wie Sie das
LmCloudSpmd2B
-Modell konfigurieren, um die von Autocheckpoint ausgelösten Prüfpunkte in einem Cloud Storage-Bucket zu speichern. Ersetzen Sie your-storage-bucket durch den Namen eines vorhandenen Buckets oder erstellen Sie einen neuen Bucket.export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Beachten Sie die beiden Flags, die an den Befehl übergeben werden:
jax_fully_async_checkpoint
: Wenn dieses Flag aktiviert ist, wirdorbax.checkpoint.AsyncCheckpointer
verwendet. Die KlasseAsyncCheckpointer
speichert automatisch einen Prüfpunkt, wenn das Trainingsskript ein SIGTERM-Signal empfängt.exit_after_ondemand_checkpoint
: Wenn dieses Flag aktiviert ist, wird der TPU-Prozess beendet, nachdem der automatische Prüfpunkt erfolgreich gespeichert wurde. Dadurch wird die Wartung sofort ausgeführt. Wenn Sie dieses Flag nicht angeben, wird das Training nach dem Speichern des Prüfpunkts fortgesetzt und die Cloud TPU wartet bis zum Auftreten einer Zeitüberschreitung (5 Minuten), bevor die erforderliche Wartung durchgeführt wird.
Autocheckpoint mit Orbax
Die Funktion Autocheckpoint ist nicht auf MaxText oder Pax beschränkt. Jedes Framework, das das SIGTERM-Signal erfassen und einen Prüfpunktprozess starten kann, funktioniert mit der von Autocheckpoint bereitgestellten Infrastruktur. Orbax, ein Namespace, der JAX-Nutzern allgemeine Utility-Bibliotheken bereitstellt, bietet diese Funktionen.
Wie in der Orbax-Dokumentation erläutert, sind diese Funktionen standardmäßig für Nutzer von orbax.checkpoint.CheckpointManager
aktiviert. Die save
-Methode, die nach jedem Schritt aufgerufen wird, prüft automatisch, ob ein Wartungsereignis bevorsteht. Wenn dies der Fall ist, wird ein Prüfpunkt gespeichert, auch wenn die Schrittnummer kein Vielfaches von save_interval_steps
ist.
In der GitHub-Dokumentation wird auch gezeigt, wie das Training nach dem Speichern eines Autocheckpoints beendet werden kann, indem der Nutzercode geändert wird.