Cloud TPU-Autocheckpoint [öffentliche Vorabversion]
Übersicht
Bisher wurde die Wartung einer TPU-VM sofort gestartet, ohne dass den Nutzern Zeit blieb, Aktionen auszuführen, die den Fortschritt sichern, z. B. das Speichern eines Checkpoints. Dies ist in Abbildung 1(a) dargestellt.
Abbildung 1 Abbildung der Funktion „Autocheckpoint“: (a) Ohne Autocheckpoint geht der Trainingsfortschritt seit dem letzten Checkpoint verloren, wenn ein Wartungsereignis bevorsteht. (b) Mit der automatischen Checkpoint-Funktion kann der Trainingsfortschritt seit dem letzten Checkpoint bei einem bevorstehenden Wartungsereignis beibehalten werden.
Mit dem automatischen Checkpoint (Abbildung 1(b)) können Sie den Trainingsfortschritt beibehalten, indem Sie Ihren Code so konfigurieren, dass ein nicht geplanter Checkpoint gespeichert wird, wenn ein Wartungsereignis auftritt. Wenn ein Wartungsereignis auftritt, wird der Fortschritt seit dem letzten Checkpoint automatisch gespeichert. Die Funktion funktioniert sowohl für einzelne Slices als auch für Multi-Slices.
Die Funktion „Autocheckpoint“ funktioniert mit Frameworks, die SIGTERM erfassen und anschließend einen Checkpoint speichern können. Zu den unterstützten Frameworks gehören MaxText, Pax und JAX mit Orbax. Unterstützung für weitere Frameworks wird bekannt gegeben, sobald sie verfügbar sind.
Derzeit können nur TPUs (v2–v4 und v5e), die über die Cloud TPU API erstellt wurden, diese Funktion nutzen. Die Unterstützung für TPUs in GKE wird angekündigt, sobald sie verfügbar ist.
Autocheckpoint verwenden
Die Funktion für automatische Checkpoints ist standardmäßig deaktiviert. Wenn Sie eine TPU oder eine in die Warteschlange gestellte Ressource erstellen, können Sie sie aktivieren, indem Sie bei der Bereitstellung der TPU das Flag --autocheckpoint-enabled
hinzufügen.
Wenn die Funktion aktiviert ist, führt Cloud TPU die folgenden Schritte aus, sobald eine Benachrichtigung zu einem Wartungsereignis eingegangen ist:
- SIGTERM erfassen, das über das TPU-Gerät an den Prozess gesendet wird,
- Wartet, bis der Prozess beendet ist oder 5 Minuten verstrichen sind, je nachdem, was zuerst eintritt, und führt Wartungsarbeiten an den betroffenen Slices durch.
Die von Autocheckpoint verwendete Infrastruktur ist unabhängig vom ML-Framework. Jedes ML-Framework kann Autocheckpoint unterstützen, sofern es das SIGTERM-Signal erfassen und einen Prüfpunktprozess initiieren kann.
Im Anwendungscode müssen Sie die vom ML-Framework bereitgestellten Funktionen für automatische Checkpoints aktivieren. In Pax bedeutet das beispielsweise, dass Befehlszeilen-Flags beim Starten des Trainings aktiviert werden müssen (siehe Schnellstart für automatische Checkpoints mit Pax). Im Hintergrund speichern die Frameworks einen nicht geplanten Checkpoint, wenn ein SIGTERM empfangen wird. Die betroffene TPU-VM wird dann gewartet, wenn die TPU nicht mehr verwendet wird.
Kurzanleitung: Autocheckpoint mit MaxText
MaxText ist ein „hochleistungsfähiger, beliebig skalierbarer, Open-Source-LLM, der in reiner Python/JAX geschrieben wurde und auf Cloud TPUs ausgerichtet ist“. MaxText enthält alle erforderlichen Einstellungen für die Verwendung der Funktion „Autocheckpoint“.
In der MaxText-Readme werden zwei Möglichkeiten zum Ausführen von MaxText im großen Maßstab beschrieben:
multihost_runner.py
wird verwendet, was für Tests empfohlen wirdmultihost_job.job
verwenden, für die Produktion empfohlen
Wenn Sie multihost_runner.py
verwenden, müssen Sie nur das Flag autocheckpoint-enabled
festlegen, wenn Sie die Ressourcen in der Warteschlange bereitstellen. Wenn Sie multihost_job.py
verwenden, müssen Sie beim Starten des Jobs nur das Befehlszeilenflag ENABLE_AUTOCHECKPOINT=true
angeben.
Kurzanleitung: Automatische Checkpoints mit Pax auf einzelnen Segmenten
In diesem Abschnitt findest du ein Beispiel für die Einrichtung und Verwendung von Autocheckpoint mit Pax auf einem einzelnen Slab. Bei entsprechender Einrichtung:
- Wenn ein Wartungsereignis auftritt, wird ein Checkpoint gespeichert.
- Nach dem Speichern des Checkpoints führt Cloud TPU Wartungsarbeiten an den betroffenen TPU-VMs durch.
- Sobald die Wartung von Cloud TPU abgeschlossen ist, können Sie die TPU-VM wie gewohnt verwenden.
Verwenden Sie das Flag
autocheckpoint-enabled
, wenn Sie die TPU-VM oder die Ressourcen in der Warteschlange erstellen.Beispiel:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Pax auf einem einzelnen Slice installieren
Die Funktion „Autocheckpoint“ funktioniert mit Pax-Versionen >= 1.1.0. Installieren Sie auf den TPU-VMs
jax[tpu]
und die neueste Version vonpaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Training mit der entsprechenden Konfiguration starten
Im folgenden Beispiel wird gezeigt, wie Sie das
LmCloudSpmd2B
-Modell so konfigurieren, dass von „Autocheckpoint“ ausgelöste Checkpoints in einem Google Cloud Storage-Bucket gespeichert werden:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Beachten Sie die beiden Flags, die an den Befehl übergeben werden:
jax_fully_async_checkpoint
: Wenn dieses Flag aktiviert ist, wirdorbax.checkpoint.AsyncCheckpointer
verwendet. Die KlasseAsyncCheckpointer
speichert automatisch einen Checkpoint, wenn das Trainingsscript ein SIGTERM-Signal empfängt.exit_after_ondemand_checkpoint
: Wenn dieses Flag aktiviert ist, werden die TPU-Prozesse beendet, nachdem der Autocheckpoint erfolgreich gespeichert wurde. Dadurch wird die Wartung sofort ausgeführt. Wenn Sie dieses Flag nicht verwenden, wird das Training nach dem Speichern des Checkpoints fortgesetzt und Cloud TPU wartet 5 Minuten, bevor die erforderliche Wartung durchgeführt wird.
Kurzanleitung: Autocheckpoint mit Pax bei Multislice
Die automatische Checkpoint-Funktion funktioniert nicht nur für einzelne Scheiben, sondern auch für Mehrere Scheiben. In diesem Abschnitt wird beschrieben, wie Sie automatische Checkpoints mit Multislice verwenden.
Geben Sie „Autocheckpoint“ beim Erstellen von Ressourcen in der Warteschlange an.
Eine Multi-Slice-Umgebung kann nur über eine in der Warteschlange befindliche Ressourcenanfrage bereitgestellt werden. Ähnlich wie beim Fall mit einer einzelnen Scheibe verwenden Sie das Flag
autocheckpoint-enabled
im Aufruf, um eine Ressourcenwarteschlange zu erstellen.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Im Nutzerhandbuch für Multislice-Scans finden Sie Details zu allen verfügbaren Optionen. Sobald die angeforderte Ressource in der Warteschlange erstellt wurde und den Status
ACTIVE
hat, führen Sie die nächsten Schritte aus, um Pax mit Autocheckpoint auszuführen.Installieren Sie Pax auf allen VMs in der Multislice-Umgebung.
Installieren Sie auf den TPU-VMs
jax[tpu]
und die neuestepaxml
auf allen TPU-VMs in Ihrer Multislice-Umgebung:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Training mit der entsprechenden Konfiguration starten
In diesem Beispiel wird gezeigt, wie Sie das Modell
LmCloudSpmd2B
für automatische Checkpoints konfigurieren, wenn Sie in einer Multislice-Umgebung trainieren. Legen Sie vor dem Ausführen des Trainingsscripts DCN_MESH_SHAPE auf [2, 1, 1] fest, wie im folgenden Code gezeigt:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Beim Starten des Trainings sind zusätzlich zu den Befehlszeilenoptionen, die im Fall mit einer einzelnen Scheibe beschrieben wurden, noch drei weitere erforderlich:
num_hosts
: die Gesamtzahl der Hosts. In diesem Fall ist das 2.host_index
: Der Index des Hosts, der das Training startet. Sie variiert zwischen 0 undN-1
, wobeiN
die Gesamtzahl der Hosts ist.server_addr
: die IP-Adresse von Worker 0 von Knoten 0 mit einem nicht verwendeten Port (z. B. 8476). Verwenden Sie dazuhostname -i
auf Worker 0 von Knoten 0.
Automatischer Checkpoint mit Orbax
Die Funktion „Autocheckpoint“ ist nicht auf MaxText oder Pax beschränkt. Jedes Framework, das das SIGTERM-Signal erfassen und einen Checkpoint-Prozess initiieren kann, funktioniert mit der von Autocheckpoint bereitgestellten Infrastruktur. Diese Funktionen bietet Orbax, ein Namespace mit gängigen Dienstbibliotheken für JAX-Nutzer.
Wie in der Orbax-Dokumentation erläutert, sind diese Funktionen für Nutzer von orbax.checkpoint.CheckpointManager
standardmäßig aktiviert. Die Methode save
, die nach jedem Schritt aufgerufen wird, prüft automatisch, ob ein Wartungsereignis bevorsteht. Falls ja, wird ein Checkpoint gespeichert, auch wenn die Schrittnummer kein Vielfaches von save_interval_steps
ist.
In der GitHub-Dokumentation wird auch veranschaulicht, wie das Training nach dem Speichern eines automatischen Checkpoints beendet werden kann, indem der Nutzercode geändert wird.