Cloud TPU Multislice – Übersicht
Cloud TPU Multislice ist eine Full-Stack-Technologie zur Leistungsskalierung. mit dem ein Trainingsjob mehrere TPU-Slices in einem einzelnen Pod oder auf in mehreren Pods mit einfacher Datenparallelität. Bei TPU v4-Chips bedeutet, dass Trainingsjobs mehr als 4.096 Chips in einem einzigen Durchlauf verwenden können. Bei Trainingsjobs, für die weniger als 4.096 Chips erforderlich sind, kann ein einzelner Sliver die beste Leistung bieten. Mehrere kleinere Scheiben sind jedoch leichter verfügbar, was zu einer kürzeren Startzeit führt, wenn Multislice mit kleineren Scheiben verwendet wird.
Bei der Bereitstellung in Multislice-Konfigurationen kommunizieren die TPU-Chips in jedem Slice über Inter-Chip-Interconnects (ICI). TPU-Chips in verschiedenen Slices kommunizieren, indem sie Daten an CPUs (Hosts) übertragen, die die Daten wiederum über das Rechenzentrumsnetzwerk (DCN) weiterleiten.
Entwickler müssen keinen Code schreiben, um die DCN-Kommunikation zwischen Slices zu implementieren. Der XLA-Compiler generiert diesen Code für Sie und überschneidet die Kommunikation mit der Berechnung, um die Leistung zu maximieren.
Konzepte
- Beschleunigertyp
- Die Form jedes TPU-Slices, das ein Multislice bildet. Jedes
in einer Anfrage mit mehreren Segmenten denselben Beschleunigertyp hat. Ein Beschleunigertyp besteht aus einem TPU-Typ (v4 oder v5e) gefolgt von der Anzahl der Tensorkerne. Beispielsweise gibt
v4-128
eine TPU v4 mit 128 TensorCores an. - Automatische Reparatur
- Wenn bei einem Slice ein Wartungsereignis, eine vorzeitige Beendigung oder ein Hardwarefehler auftritt, Cloud TPU erstellt ein neues Slice. In dem seltenen Fall, dass nicht genügend Ressourcen für die Erstellung eines neuen Slices vorhanden sind, wird die Erstellung erst abgeschlossen, wenn Hardware verfügbar ist. Nachdem das neue Segment erstellt wurde, in der Multislice-Umgebung neu gestartet. Mit einem korrekt konfigurierten Startskript können ohne Eingriff des Nutzers automatisch neu gestartet werden. ab dem letzten Checkpoint.
- Dataset
- Die Daten, die von einem Modell für Training oder Inferenz verwendet werden.
- Rechenzentrumsnetzwerk (DCN)
- Ein Netzwerk mit höherer Latenz und geringem Durchsatz (im Vergleich zu ICI), das verbindet TPU-Slices in einer Multislice-Konfiguration.
- Gruppenplanung
- Wenn alle TPU-Slices gemeinsam bereitgestellt werden, Entweder wurden alle oder keine der Slices bereitgestellt.
- Moderator:in
- Ein Host ist ein physischer Computer, auf dem VMs ausgeführt werden. Auf einem Host können maximal vier VMs gleichzeitig ausgeführt werden. Jede VM hat eine dedizierte TPU.
- Inferenz
- Laden Sie ein vortrainiertes Modell für maschinelles Lernen auf einen Host und treffen Sie Vorhersagen für Daten.
- Interchip Interconnect (ICI)
- Interne Hochgeschwindigkeitsverbindungen mit niedriger Latenz, die TPUs innerhalb eines TPU-Pods verbinden.
- Mehrfachschicht
- Zwei oder mehr TPU-Chip-Slices, die über das DCN kommunizieren können.
- Knoten
- Im Multislice-Kontext bezieht sich „Knoten“ auf ein einzelnes TPU-Stück. Jedes TPU-Slices in einem Multislice wird eine Knoten-ID zugewiesen.
- Pod
- Eine Sammlung von TPU-Chips, die über dedizierte ICI-Netzwerkschnittstellen verbunden sind. A Mit einem Pod können Sie die Verarbeitungslast auf mehrere TPUs verteilen.
- Ressource in der Warteschlange (QR)
- Eine Darstellung von TPU-Ressourcen, mit denen eine Anfrage für ein Element in die Warteschlange gestellt und verwaltet wird Einzel- oder Multislice-TPU-Umgebung.
- Startskript
- Ein standardmäßiges Compute Engine-Startskript die jedes Mal ausgeführt wird, wenn eine VM gestartet oder neu gestartet wird. Bei Multislice wird sie in der Anfrage zum Erstellen des QR-Codes angegeben. Weitere Informationen Informationen zu Cloud TPU-Startskripts finden Sie unter TPU-Ressourcen verwalten.
- TPU-Slice
- Logischer Teilbereich eines TPU-Pods, der aus TPU-Chips besteht. Alle Chips in einem Slice kommunizieren über das ICI-Netzwerk miteinander.
- TPU-VM
- Eine virtuelle Maschine mit Linux, die Zugriff auf die zugrunde liegenden TPUs hat. Für v4-TPUs hat jede TPU-VM direkten Zugriff auf vier Chips. Manchmal wird eine TPU-VM auch als Worker bezeichnet.
- Tensor
- Eine Datenstruktur, die zum Darstellen mehrdimensionaler Daten in einem Modell für maschinelles Lernen verwendet wird.
- Tensor Processing Unit (TPU)
- Der intern entwickelte Chip für die ML-Beschleunigung von Google. Sie sind für schnelle und energieeffiziente Berechnungen bei wichtigen Aufgaben des maschinellen Lernens wie der Matrixmultiplikation ausgelegt.
- Arten von Cloud TPU-Kapazitäten
TPUs können aus verschiedenen Arten von Kapazität erstellt werden (siehe „Nutzungsoptionen“ im Artikel So funktionieren die TPU-Preise):
- Reservierung: Targeting auf reserviertes Kontingent. Zur Verwendung reservierter Kontingente benötigen Sie eine
Reservierungsvereinbarung mit Google. Verwenden Sie beim Erstellen das Flag
--reserved
. Ihre Ressourcen. - Spot: Targeting auf Kontingent auf Abruf mithilfe von Spot-VMs. Ihr
werden unter Umständen vorzeitig beendet, um Platz für Anfragen für eine höhere
priorisierten Job. Verwenden Sie beim Erstellen Ihrer Ressourcen das Flag
--spot
. - On-Demand: Gezielt auf ein On-Demand-Kontingent, das keine Reservierung erfordert und nicht vorzeitig beendet. Die TPU-Anfrage wird in eine On-Demand- von Cloud TPU angebotene Kontingentwarteschlange, ist die Verfügbarkeit von Ressourcen garantiert. Standardmäßig ausgewählt, keine Flags erforderlich.
- Reservierung: Targeting auf reserviertes Kontingent. Zur Verwendung reservierter Kontingente benötigen Sie eine
Reservierungsvereinbarung mit Google. Verwenden Sie beim Erstellen das Flag
Mehr erfahren
Wenn Sie noch keine TPUs verwendet haben, installieren Sie zuerst die Google Cloud CLI. und richten Sie Ihre Cloud TPU-Umgebung ein. Um Multislice verwenden, müssen Ihre TPU-Ressourcen als Ressourcen in der Warteschlange verwaltet werden.
Wenn Sie bereits TPU v4 verwenden und eine Reservierung haben, müssen Sie diese möglicherweise in ein neues Reservierungssystem migrieren. Weitere Informationen wenden Sie sich an Ihren Google Cloud-Kundenbetreuer.
Beispiel für die Einführung
In dieser Anleitung wird Code aus dem GitHub-Repository von MaxText verwendet. MaxText ist ein leistungsstarkes, beliebig skalierbares Open-Source-Projekt, das sich umfassend getestet hat in Python und Jax geschrieben. MaxText wurde für ein effizientes Training auf Cloud TPU
Der Code in shardings.py
soll Ihnen den Einstieg in verschiedene Parallelisierungsoptionen erleichtern. Zum Beispiel Datenparallelität, vollständig fragmentierte Datenparallelität (Full Sharded Data Parallelism, FSDP),
und Tensor-Parallelität. Der Code kann von einer einzelnen Schleife auf Multislice-Umgebungen skaliert werden.
ICI-Parallelität
ICI bezieht sich auf die Hochgeschwindigkeitsverbindung, die die TPUs in einem einzelnen Slice verbindet. ICI-Sharding entspricht dem Sharding innerhalb eines Segments. shardings.py
stellt drei ICI-Parallelitätsparameter bereit:
ici_data_parallelism
ici_fsdp_parallelism
ici_tensor_parallelism
Die Werte, die Sie für diese Parameter angeben, bestimmen die Anzahl der Shards für jede Parallelisierungsmethode.
Diese Eingaben müssen so eingeschränkt sein, dass ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism
der Anzahl der Chips im Slice entspricht.
Die folgende Tabelle zeigt Beispielnutzereingaben für die ICI-Parallelität für die vier In v4-8 verfügbare Chips:
ici_data_parallelism | ici_fsdp_parallelism | ici_tensor_parallelism | |
4-Wege-FSDP | 1 | 4 | 1 |
4-Wege-Tensor-Parallelität | 1 | 1 | 4 |
2-Wege-FSDP + 2-Wege-Tensor-Parallelität | 1 | 2 | 2 |
Hinweis: ici_data_parallelism
sollte in den meisten Fällen bei 1 belassen werden, da das ICI-Netzwerk schnell genug ist, um fast immer FSDP dem Datenparallelismus vorzuziehen.
In diesem Beispiel wird davon ausgegangen, dass Sie mit dem Ausführen von Code auf einem einzelnen TPU-Speicherbereich vertraut sind, z. B. wie im Artikel Berechnung mit JAX auf einer Cloud TPU-VM ausführen.
Dieses Beispiel zeigt, wie shardings.py
für ein einzelnes Segment ausgeführt wird.
Richten Sie die Umgebung ein:
$ gcloud auth login $ gcloud config set project your-project-id $ gcloud config set compute/zone your-zone
Erstellen Sie SSH-Schlüssel für
gcloud
. Wir empfehlen, das Passwort leer zu lassen. Drücken Sie dazu nach Ausführung des folgenden Befehls zweimal die Eingabetaste. Wenn Sie aufgefordert werden, die Dateigoogle_compute_engine
zu ersetzen, weil sie bereits vorhanden ist, ersetzen Sie die vorhandene Version.$ ssh-keygen -f ~/.ssh/google_compute_engine
Stellen Sie Ihre TPUs mit dem folgenden Befehl bereit:
$ gcloud compute tpus queued-resources \ create your-qr-id \ --accelerator-type your-accelerator-type \ --runtime-version tpu-ubuntu2204-base \ --node-id qr-id \ [--reserved |--spot]
Beschreibung der Befehls-Flags
your-qr-id
- Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
accelerator-type
- Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
runtime-version
- Die [Cloud TPU-Softwareversion](/tpu/docs/supported-tpu-configurations#tpu_software_versions).
node-id
- Die ID der TPU-Ressourcen, die als Reaktion auf die QR-Anfrage.
reserved
- Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
spot
- Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.
Die Google Cloud CLI unterstützt nicht alle Optionen zum Erstellen von QR-Codes, z. B. Tags. Weitere Informationen finden Sie unter QR-Codes erstellen.
Warten Sie, bis sich der QR-Code im Status
ACTIVE
befindet. Das bedeutet, dass sich die Worker-Knoten den StatusREADY
haben. Sobald die QR-Bereitstellung gestartet wurde, kann es je nach Größe des QR-Codes ein bis fünf Minuten dauern, bis sie abgeschlossen ist. Sie können den Status mit dem folgenden Befehl erstellen:$ gcloud compute tpus queued-resources \ list --filter=your-qr-id
Ein v4-8-Slice hat eine einzelne TPU-VM. Stellen Sie eine SSH-Verbindung zur TPU-VM her:
$ gcloud compute tpus tpu-vm ssh your-qr-id
Klonen Sie MaxText (einschließlich
shardings.py
) auf Ihre TPU-VM.Führen Sie im MaxText-Repository-Verzeichnis das Einrichtungsskript aus, um JAX und andere Abhängigkeiten auf Ihrem TPU-Speichereinsatz zu installieren. Das Einrichtungsskript Minuten gelaufen.
$ bash setup.sh
Führen Sie den folgenden Befehl aus, um
shardings.py
auf Ihrem TPU-Speicherbereich auszuführen.$ python3 pedagogical_examples/shardings.py \ --ici_fsdp_parallelism 4 \ --batch_size 131072 \ --embedding_dimension 2048
Sie können die Ergebnisse in den Logs sehen. Ihre TPUs sollten etwa 260 TFLOP/Sekunde erreichen, was einer beeindruckenden FLOP-Auslastung von über 90 % entspricht. In diesem Fall haben wir ungefähr den maximalen Batch ausgewählt, der in den hohen Wert der TPU passt Bandwidth Memory (HBM) (Bandbreitenspeicher).
Sie können auch andere Sharding-Strategien über ICI ausprobieren. Probieren Sie beispielsweise die folgende Kombination aus:
$ python3 pedagogical_examples/shardings.py \ --ici_tensor_parallelism 4 \ --batch_size 131072 \ --embedding_dimension 2048
Löschen Sie die QR- und TPU-Scheiben, wenn Sie fertig sind. Führen Sie diese Schritte zur Bereinigung in der Umgebung aus, in der Sie das Slice eingerichtet haben. Führen Sie zuerst
exit
aus, um die SSH-Sitzung zu beenden. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag--async
im Hintergrund ausgeführt werden.$ gcloud compute tpus queued-resources delete your-qr-id --force (--async)
Mehrfachscheiben-Sharding mit DCN-Parallelität
Das Skript shardings.py
verwendet drei Parameter, die die DCN-Parallelität festlegen.
entspricht der Anzahl der Shards jedes Typs von Datenparallelität:
- dcn_data_parallelism
- dcn_fsdp_parallelism
- dcn_tensor_parallelism
Die Werte dieser Parameter müssen so eingeschränkt sein, dass dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism
der Anzahl der Scheiben entspricht.
Verwenden Sie als Beispiel für zwei Segmente --dcn_data_parallelism = 2
.
dcn_data_parallelism | dcn_fsdp_parallelism | dcn_tensor_parallelism | Anzahl der Segmente | |
Bidirektionale Datenparallelität | 2 | 1 | 1 | 2 |
dcn_tensor_parallelism
sollte immer auf 1
gesetzt werden, da das DCN schlecht ist.
für eine solche Fragmentierung geeignet. Bei typischen LLM-Arbeitslasten auf v4-Chips
dcn_fsdp_parallelism
sollte auch auf 1
gesetzt werden, damit
dcn_data_parallelism
sollte auf die Anzahl der Segmente festgelegt werden,
anwendungsabhängig.
Wenn Sie die Anzahl der Segmente erhöhen (vorausgesetzt, Sie halten die Segmentgröße und den Batch pro Segment konstant), erhöhen Sie die Datenparallelität.
shardings.py
in einer Multislice-Umgebung ausführen
Sie können shardings.py
in einer Umgebung mit mehreren Segmenten mit
multihost_runner.py
oder durch Ausführen von shardings.py
auf jeder TPU-VM. Hier verwenden wir multihost_runner.py
. Die folgenden Schritte ähneln denen unter Einstieg: Schnelle Tests für mehrere Segmente im MaxText-Repository. Der Unterschied besteht darin, dass hier shardings.py
anstelle des komplexeren LLM in train.py
verwendet wird.
Das multihost_runner.py
-Tool ist für schnelle Tests optimiert und verwendet wiederholt dieselben TPUs. Da das Skript multihost_runner.py
von
langlebigen SSH-Verbindungen verwenden möchten, empfehlen wir diese Methode nicht für Jobs mit langer Ausführungszeit.
Wenn Sie einen längeren Job ausführen möchten (z. B. Stunden oder Tage), empfehlen wir Ihnen,
multihost_job.py
verwenden.
In dieser Anleitung geben wir mit dem Begriff runner die Maschine an, auf der Sie
führen Sie das Skript multihost_runner.py
aus. Wir verwenden den Begriff Worker für
TPU-VMs, aus denen Ihre Slices bestehen. Sie können multihost_runner.py
auf einem lokalen
oder eine beliebige Compute Engine-VM
im selben Projekt wie Ihre Slices. Laufen
multihost_runner.py
für einen Worker wird nicht unterstützt.
multihost_runner.py
stellt automatisch eine SSH-Verbindung zu TPU-Workern her.
In diesem Beispiel wird shardings.py
über zwei v4-16-Segmente ausgeführt, insgesamt vier VMs und 16 TPU-Chips. Sie können das Beispiel so ändern, dass es auf mehr TPUs ausgeführt wird.
Umgebung einrichten
MaxText im Runner klonen Maschine.
Rufen Sie das Repository-Verzeichnis auf.
Erstellen Sie SSH-Schlüssel für
gcloud
. Wir empfehlen, das Passwort leer zu lassen. Drücken Sie nach Ausführung des folgenden Befehls zweimal die Eingabetaste. Wenn Sie aufgefordert werden, Die Dateigoogle_compute_engine
ist bereits vorhanden. Wählen Sie aus, dass Sie Ihre vorhandenen Version.$ ssh-keygen -f ~/.ssh/google_compute_engine
Fügen Sie eine Umgebungsvariable hinzu, um die Anzahl der TPU-Slices auf
2
festzulegen.$ export SLICE_COUNT=2
Erstellen Sie mit
queued-resources create
eine Multislice-Umgebung.Der folgende Befehl zeigt, wie Sie eine v4-Multislice-TPU erstellen. Um v5e einen
accelerator-type
für v5e (z. B.v5litepod-16
) und den Version 5eruntime-version
(v2-alpha-tpuv5-lite
).$ gcloud compute tpus queued-resources
create your-qr-id
--accelerator-type=your-accelerator-type
--runtime-version=tpu-vm-runtime-version
--node-count=node-count
--node-prefix=your-qr-id
[--reserved|--spot]Beschreibung der Befehls-Flags
your-qr-id
- Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
accelerator-type
- Mit dem Beschleunigertyp geben Sie die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
runtime-version
- Die Cloud TPU-Softwareversion.
node-count
- Die Anzahl der zu erstellenden Segmente.
node-prefix
- Das Präfix, mit dem Namen für die einzelnen Segmente generiert werden. Eine Zahl wird angehängt
bis zum Präfix hinzu. Wenn Sie beispielsweise
node-prefix
aufmySlice
festlegen, werden die SlicesmySlice-0
,mySlice-1
usw. benannt. reserved
- Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
spot
- Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.
Wenn die QR-Bereitstellung gestartet wird, kann es bis zu fünf Minuten dauern, bis die Bereitstellung abgeschlossen ist je nach Größe des QR-Codes. Warten Sie, bis sich die Ressource in der Warteschlange (QR) in der Warteschlange befindet
ACTIVE
. Sie können den Status einer QR-Anfrage mit der folgenden Befehl:$ gcloud compute tpus queued-resources list \ --filter=your-qr-id
Die Ausgabe sollte in etwa so aussehen:
NAME ZONE NODE_COUNT ACCELERATOR_TYPE STATE ... que-res-id us-central2-b 4 v4-16 ACTIVE ...
Wenden Sie sich an Ihren Google Cloud-Kundenbetreuer, wenn der QR-Status in der
WAITING_FOR_RESOURCES
- oderPROVISIONING
-Status für mehr als 15 Minuten.Installieren Sie die Abhängigkeiten:
$ python3 multihost_runner.py \ --TPU_PREFIX=your-qr-id \ --COMMAND="bash setup.sh"
Führen Sie
shardings.py
mitmultihost_runner.py
auf jedem Worker aus.$ python3 multihost_runner.py \ --TPU_PREFIX=your-qr-id \ --COMMAND="python3 pedagogical_examples/shardings.py \ --dcn_data_parallelism $SLICE_COUNT \ --ici_fsdp_parallelism 8 \ --batch_size 131072 \ --embedding_dimension 2048"
Im Protokoll werden ungefähr 230 TFLOPs pro Sekunde der Leistung angezeigt. -Dateien.
Bereinigen Sie die TPUs und den QR-Code, wenn Sie fertig sind. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag
--async
im Hintergrund ausgeführt werden.
Arbeitslast auf Multislice skalieren
Bevor Sie Ihr Modell in einer Multislice-Umgebung ausführen, nehmen Sie die folgenden Codeänderungen vor:
- Verwenden Sie beim Erstellen des Mesh jax.experimental.mesh_utils.create_hybrid_device_mesh anstelle von jax.experimental.mesh_utils.create_device_mesh.
Dies sollten die einzigen erforderlichen Codeänderungen sein, die beim Wechsel zur Multislice-Funktion erforderlich sind. Um eine hohe Leistung zu erzielen, muss das DCN parallel und vollständig den Daten zugeordnet werden. parallele Achsen fragmentierter Daten oder parallele Pipelineachsen. Leistungsaspekte und Sharding-Strategien werden im Artikel Sharding mit Multislice für maximale Leistung ausführlicher behandelt.
Um zu überprüfen, ob Ihr Code auf alle Geräte zugreifen kann, können Sie bestätigen,
len(jax.devices())
entspricht der Anzahl der Chips im Multislice
zu verbessern. Wenn Sie beispielsweise vier Scheiben von v4-16
verwenden, haben Sie acht Chips pro Scheibe × 4 Scheiben, sodass len(jax.devices())
den Wert 32 zurückgeben sollte.
Scheibengrößen für Multi-Scheiben-Umgebungen auswählen
Wenn Sie eine lineare Beschleunigung erzielen möchten, fügen Sie neue Segmente mit derselben Größe wie das vorhandene Segment hinzu. Wenn Sie beispielsweise einen v4-512
-Speicherplatz verwenden, kann Multislice die Leistung ungefähr verdoppeln, indem ein zweiter v4-512
-Speicherplatz hinzugefügt und die globale Batchgröße verdoppelt wird. Weitere Informationen finden Sie unter Sharding mit Multislice für maximale Leistung.
Job auf mehreren Slices ausführen
Es gibt drei verschiedene Ansätze, um Ihre benutzerdefinierte Arbeitslast in einem Umgebung mit mehreren Segmenten:
- Mit dem Test-Runner-Script
multihost_runner.py
- Mit dem Produktions-Runner-Skript
multihost_job.py
- Manueller Ansatz
Script für Testausführung
Die multihost_runner.py
Das Skript verteilt den Code an eine vorhandene Multislice-Umgebung und führt
Ihren Befehl auf jedem Host ausführen, Ihre Protokolle zurückkopieren und die Fehler der einzelnen Befehle nachverfolgen.
Status. Das multihost_runner.py
-Script ist in der MaxText-README dokumentiert.
Da multihost_runner.py
persistente SSH-Verbindungen unterhält, ist es nur
eignet sich für mittelgroße, relativ kurze Tests. Sie können
Passen Sie die Schritte in der multihost_runner.py-Anleitung an.
an Ihre Arbeitslast und Hardwarekonfiguration anpassen.
Skript für die Produktionsausführung
Für Produktionsjobs, die Resilienz gegen Hardwarefehler und andere
vorzeitigem Beenden zu deaktivieren, sollten Sie dies am besten direkt in die
der API erstellen. Als Beispiel stellen wir multihost_job.py
bereit,
löst den API-Aufruf Created Queued Resource API mit dem entsprechenden Startvorgang aus.
um das Training auszuführen und bei vorzeitigem Beenden fortzusetzen. Das multihost_job.py
Skript im
README-Datei für MaxText.
Da multihost_job.py
für jede Ausführung Ressourcen bereitstellen muss, wird dies nicht
einen so schnellen Iterationszyklus wie multihost_runner.py
bereitstellen.
Manueller Ansatz
Wir empfehlen, multihost_runner.py zu verwenden oder anzupassen oder multihost_job.py, um Ihre benutzerdefinierte Arbeitslast in Ihre Multislice-Konfiguration. Wenn Sie jedoch eine Bereitstellung und können Sie Ihre Umgebung direkt über QR-Befehle verwalten, siehe Multislice-Umgebung verwalten
Multislice-Umgebung verwalten
So können Sie QR-Codes manuell bereitstellen und verwalten, ohne die Tools zu verwenden im MaxText-Repository bereitgestellt haben, lesen Sie die folgenden Abschnitten.
QR-Codes erstellen
Legen Sie die folgenden Umgebungsvariablen fest, bevor Sie Kapazität bereitstellen:
$ export your-qr-id=your-queued-resource-id $ export PROJECT=your-project-name $ export ZONE=us-central2-b $ export NETWORK_NAME=your-network-name $ export SUBNETWORK_NAME=your-subnetwork-name $ export RUNTIME_VERSION=tpu-ubuntu2204-base $ export ACCELERATOR_TYPE=v4-16 $ export SLICE_COUNT=4 $ export STARTUP_SCRIPT="#!/bin/bash\n ..." $ gcloud config set project project-name $ gcloud config set compute/zone zone
Eingabe | Beschreibung |
your-qr-id | Die vom Nutzer zugewiesene ID des QR-Codes. |
PROJECT | Name des Google Cloud-Projekts |
ZONE | us-central2-b |
NETWORK_NAME | Name der VPC-Netzwerke. |
SUBNETWORK_NAME | Name des Subnetzes in VPC-Netzwerken |
RUNTIME_VERSION | tpu-Ubuntu2204-base |
ACCELERATOR_TYPE | v4-16 |
Bsp_TAG_1, Bsp_TAG_2… | Tags, mit denen gültige Quellen oder Ziele für Netzwerk-Firewalls angegeben werden |
SLICE_COUNT | Anzahl der Segmente. Begrenzt auf maximal 256 Segmente. |
STARTUP_SCRIPT | Falls zur Erstellungsanfrage hinzugefügt: Ein Startskript kann immer dann ausgeführt werden, wenn ein TPU-Slice bereitgestellt oder neu gestartet wird. und ob das TPU-Slice repariert oder zurückgesetzt wurde. |
QR-Anfrage mit gcloud
erstellen
$ gcloud compute tpus queued-resources \ create ${your-qr-id} \ --project your-project-id \ --zone your-zone \ --node-count ${SLICE_COUNT} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --network ${NETWORK_NAME} \ --subnetwork ${SUBNETWORK_NAME} \ --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}' [--reserved|--spot]
Beschreibung der Befehls-Flags
your-qr-id
- Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
project
- Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
zone
- Die Google Cloud-Zone, in der der QR-Code erstellt werden soll.
node-count
- Die Anzahl der zu erstellenden Segmente.
accelerator-type
- Mit dem Beschleunigertyp geben Sie die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
runtime-version
- Die Version der Cloud TPU-Software.
network
- Der Name eines VPC-Netzwerk, an das die TPU-Ressource angehängt werden soll.
subnetwork
- Der Name eines VPC-Subnetzwerks, an das die TPU-Ressource angehängt werden soll.
reserved
- Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
spot
- Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.
Achten Sie darauf, dass Sie das entsprechende Kontingent haben, bevor Sie --reserved
, --spot
oder das Standardkontingent für On-Demand-Anzeigeaufträge auswählen. Informationen zu Kontingenttypen
Siehe Kontingentrichtlinie.
QR-Anfrage mit curl
erstellen
Erstellen Sie eine Datei mit dem Namen queued-resource-req.json
und kopieren Sie den folgenden JSON-Code hinein.
{ "guaranteed": { "reserved": true }, "tpu": { "node_spec": [ { "parent": "projects/your-project-number/locations/your-zone", "node": { "accelerator_type": "accelerator-type", "runtime_version": "tpu-vm-runtime-version", "network_config": { "network": "your-network-name", "subnetwork": "your-subnetwork-name", "enable_external_ips": true }, "tags" : ["example-tag-1"] "metadata": { "startup-script": "your-startup-script" } }, "multi_node_params": { "node_count": slice-count, "node_id_prefix": "your-queued-resource-id" } } ] } }
- your-project-number: Ihre Google Cloud-Projektnummer
- your-zone – die Zone, in der Sie den QR-Code erstellen möchten
- accelerator-type – Version und Größe eines einzelnen Slice
- tpu-vm-runtime-version – die Laufzeitversionen der TPU-VM
- your-network-name – Optional ein Netzwerk, an das der QR-Code angehängt wird
- your-subnetwork-name – Optional: ein Subnetzwerk, an das der QR-Code angehängt werden soll
- example-tag-1 – Optionaler, beliebiger Tag-String
- your-startup-script – ein Startskript, das ausgeführt wird, wenn der QR-Code zugewiesen wird
- slice-count – Die Anzahl der TPU-Slices in Ihrer Multislice-Umgebung
- your-qr-id – Die vom Nutzer angegebene ID für den QR-Code
Weitere Informationen zu allen verfügbaren Optionen findest du in der Dokumentation zur REST Queued Resource API.
Wenn Sie die Spot-Kapazität verwenden möchten, ersetzen Sie Folgendes:
"guaranteed": { "reserved": true }
mit "spot": {}
Entfernen Sie die Zeile, um die Standardkapazität für On-Demand-Kapazitäten zu verwenden.
Senden Sie die Anfrage zum Erstellen des QR-Codes mit der JSON-Nutzlast:
$ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" -d @queuedresourcereq.json https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
- your-project-id – Ihre Google Cloud-Projekt-ID
- your-zone – die Zone, in der Sie den QR-Code erstellen möchten
- your-qr-id – Die vom Nutzer angegebene ID für den QR-Code
Die Antwort sollte in etwa so aussehen:
{ "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>", "metadata": { "@type": "type.googleapis.com/google.cloud.common.OperationMetadata", "createTime": "2023-11-01T00:17:05.742546311Z", "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>", "verb": "create", "cancelRequested": false, "apiVersion": "v2alpha1" }, "done": false }
Verwenden Sie den GUID-Wert am Ende des Stringwerts für das Attribut name
, um Informationen zur QR-Anfrage zu erhalten.
Status eines QR-Codes abrufen
Verwenden Sie den folgenden Befehl, um den Status der QR-Anfrage abzurufen:
$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
- your-project-id – Ihre Google Cloud-Projekt-ID
- your-zone – Der Bereich, in dem der QR-Code erstellt werden soll.
- your-qr-guid: Die GUID, die in der Ausgabe der Anfrage zum Erstellen des QR-Codes auf
name
folgt.
Die Antwort auf diesen Befehl enthält den Status des Vorgangs:
{ "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qa-guid>, "metadata": {...}, "done": true, "response": { "@type": "type.googleapis.com/google.cloud.tpu.v2.QueuedResource", ... "state": { "state": "WAITING_FOR_RESOURCES" } } }
Wenn der QR-Code erstellt wurde, ("done = true")
, wird der Status innerhalb der
Das Feld response
ist entweder WAITING_FOR_RESOURCES
oder FAILED
.
Wenn der QR-Code den Status WAITING_FOR_RESOURCES
hat, wurde er
und mit der Bereitstellung beginnen, sobald genügend Ressourcen vorhanden sind. Wenn der QR-Code den Status FAILED
hat, wird der Grund für den Fehler in der Ausgabe angezeigt. Weitere Informationen
Informationen zu anderen möglichen Status finden Sie in der
Nutzerhandbuch für Ressourcen in der Warteschlange.
Sobald der Vorgang abgeschlossen ist, verwenden Sie die Kurzantworten beschreiben. um die Phasen des QR-Codes zu überwachen.
In seltenen Fällen kann es vorkommen, dass Ihr QR-Code den Status FAILED
hat, während einige
Slices ACTIVE
. Löschen Sie in diesem Fall die erstellten Ressourcen
und versuche es in ein paar Minuten noch einmal oder wende dich an uns
an das Cloud TPU-Team senden, um das Problem zu beheben.
SSH und Abhängigkeiten installieren
Im Hilfeartikel JAX-Code auf TPU-Pod-Slices ausführen wird beschrieben, wie Sie in einem einzelnen Slice eine SSH-Verbindung zu Ihren TPU-VMs herstellen. Bis
eine Verbindung zu allen TPU-VMs in Ihrer Multislice-Umgebung über SSH und
um Abhängigkeiten zu installieren, verwenden Sie den folgenden gcloud
-Befehl:
$ gcloud compute tpus queued-resources ssh ${your-qr-id} \ --zone your-zone \ --node=all \ --worker=all \ --command="command-to-run" --batch-size=4
Mit diesem gcloud
-Befehl wird der angegebene Befehl an alle Worker und Knoten in
QR-Code mit SSH Der Befehl wird in Gruppen von vier zusammengefasst und gleichzeitig gesendet. Der nächste Batch von Befehlen wird gesendet, wenn der aktuelle Batch
um die Ausführung abzuschließen. Wenn bei einem der Befehle ein Fehler auftritt, wird die Verarbeitung beendet und es werden keine weiteren Batches gesendet. Weitere Informationen finden Sie in der
API-Referenz für Ressourcen in der Warteschlange
Wenn die Anzahl der verwendeten Slices das Threading des lokalen Computers überschreitet
(auch Batch-Limit genannt),
besteht ein Deadlock. Hier ein Beispiel:
davon ausgehen, dass das Batchlimit auf Ihrem lokalen Rechner bei 64 liegt. Wenn Sie versuchen,
auf mehr als 64 Slices, z. B. 100 Slices, den SSH-Befehl
in Batches aufteilen. Das Trainingsskript wird beim ersten Batch von 64
und warten Sie, bis die Skripte abgeschlossen sind, bevor Sie sie auf dem
verbleibenden Batch von 36 Segmenten. Der erste Batch von 64 Segmenten kann jedoch nicht
bis die verbleibenden 36 Slices beginnen,
das Skript auszuführen, was zu einer
Deadlock.
Um dies zu verhindern, können Sie das Trainingsskript im Hintergrund auf
jede VM durch Anhängen eines kaufmännischen Und-Zeichens (&
) an den von Ihnen angegebenen Skriptbefehl
mit dem Flag --command
. Wenn Sie dies tun, wird die Steuerung nach dem Starten des Trainingsscripts für die erste Gruppe von Segmenten sofort an den SSH-Befehl zurückgegeben. Der SSH-Befehl kann
dann das Trainingsskript auf
der restlichen 36 Segmente. Sie müssen stdout
und stderr
streamt sie ordnungsgemäß, wenn sie im Hintergrund ausgeführt wird. Zum Erhöhen
Parallelität innerhalb desselben QR-Codes können Sie über die --node
-Schaltfläche bestimmte Segmente auswählen.
.
Netzwerkeinrichtung
Führen Sie die folgenden Schritte aus, um sicherzustellen, dass TPU-Slices miteinander kommunizieren können.
Installieren Sie JAX auf allen Slices. Weitere Informationen finden Sie unter JAX-Code auf TPU-Pod-Slices ausführen. Behaupten, dass
len(jax.devices())
entspricht der Anzahl der Chips im Multislice
zu verbessern. Führen Sie dazu für jedes Segment folgenden Befehl aus:
$ python3 -c 'import jax; print(jax.devices())'
Wenn Sie diesen Code auf vier Abschnitten von v4-16 ausführen, gibt es acht Chips pro
Slice und vier Slices werden insgesamt 32 Chips (Geräte) zurückgegeben.
von jax.devices()
.
QR-Codes auflisten
Mit dem Befehl queued-resources list
können Sie den Status Ihrer QR-Codes aufrufen:
$ gcloud compute tpus queued-resources list NAME ZONE NODE_COUNT ACCELERATOR_TYPE STATE ... que-res-id us-central2-b 4 v4-16 ACTIVE ...
QR-Codes beschreiben
Um die detaillierte Konfiguration und den Status eines QR-Codes anzuzeigen, verwende die
um die QR-API zu beschreiben. Sie können diese API mit gcloud
oder curl
aufrufen.
mit gcloud
:
$ gcloud compute tpus queued-resources describe ${your-qr-id} ...state: state: ACTIVE ...
mit curl
:
$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/${your-qr-id} { "name": your-queued-res, "tpu": { "nodeSpec": [ { ... // node 1 }, { ... // node 2 }, ... ] }, ... "state": "ACTIVE" }
state
steht für den Status eines QR-Codes. Weitere Informationen zu den möglichen
Informationen zum Status von QR-Codes finden Sie unter Ressourcen in der Warteschlange.
Job in einer bereitgestellten Umgebung starten
Sie können Arbeitslasten manuell ausführen, indem Sie über SSH eine Verbindung zu allen Hosts in jedem Segment herstellen und führen den folgenden Befehl auf allen Hosts aus.
$ gcloud compute tpus tpu-vm ssh your-qr-id \ --zone=your-zone \ --worker=all \ --node=all \ --command="command-to-run"
QR-Codes zurücksetzen
Die ResetQueuedResource
API kann zum Zurücksetzen verwendet werden
alle VMs in einem ACTIVE
-QR-Code. Durch das Zurücksetzen der VMs wird der Speicher der Maschine gelöscht und die VM auf ihren Ausgangszustand zurückgesetzt. Alle lokal gespeicherten Daten bleiben intakt und das Startskript wird nach dem Zurücksetzen aufgerufen. Die
Die ResetQueuedResource
API kann nützlich sein, wenn Sie alle TPUs neu starten möchten. Für
z. B. wenn das Training hängen bleibt und das Zurücksetzen aller VMs einfacher ist als das Debugging.
Das Zurücksetzen aller VMs erfolgt parallel und ein ResetQueuedResource
-Vorgang dauert ein bis zwei Minuten. Verwenden Sie den folgenden Befehl, um die API aufzurufen:
$ gcloud compute tpus queued-resources reset your-qr-id
QR-Codes löschen
Wenn Sie Ressourcen am Ende der Trainingssitzung freigeben möchten, löschen Sie die in der Warteschlange befindliche Ressource mit dem Flag --force
. Der Löschvorgang dauert zwei bis fünf Minuten,
abgeschlossen und kann mit dem optionalen Flag --async
im Hintergrund ausgeführt werden.
$ gcloud compute tpus queued-resources \ delete your-qr-id --force (--async)
Automatische Wiederherstellung nach Fehlern
Bei einer Störung bietet Multislice eine störungsfreie Reparatur des betroffenen Slices und ein anschließendes Zurücksetzen aller Slices. Betroffene Slice wird durch ein neues Slice und die restlichen ansonsten fehlerfreien Slices ersetzt. zurückgesetzt. Wenn keine Kapazität zum Zuweisen eines Ersatz-Slice verfügbar, wird das Training beendet.
Wenn das Training nach einer Unterbrechung automatisch fortgesetzt werden soll, müssen Sie ein Startscript angeben, das nach den zuletzt gespeicherten Checkpoints sucht und sie lädt. Das Startskript wird jedes Mal automatisch ausgeführt, wenn ein Slab neu zugewiesen oder eine VM zurückgesetzt wird. Sie geben ein Start-up an in der JSON-Nutzlast, die Sie an die API zum Erstellen der QR-Anfrage senden.
Mit dem folgenden Startskript (verwendet in QR-Codes erstellen) können Sie automatisch nach Fehlern wiederherstellen und das Training anhand von Checkpoints fortsetzen, die während des MaxText-Trainings in einem Cloud Storage-Bucket gespeichert wurden:
{ "tpu": { "node_spec": [ { ... "metadata": { "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF" } ... } ] } }
Klonen Sie das MaxText-Repository, bevor Sie dies ausprobieren.
Profiling und Fehlerbehebung
Das Profiling ist in Umgebungen mit einer einzelnen und mehreren Scheiben identisch. Weitere Informationen finden Sie unter Profilerstellung für JAX-Programme.
Optimierte Schulungen
Fragmentierung mit Multislice für maximale Leistung
Um in Multislice-Umgebungen maximale Leistung zu erzielen, über mehrere Segmente aufteilen. Es gibt in der Regel drei Möglichkeiten: Datenparallelität, vollständig shardete Datenparallelität und Pipelineparallelität. Wir raten davon ab, Aktivierungen auf die Modelldimensionen zu verteilen (manchmal Tensor-Parallelität), da dafür zu viel Bandbreite zwischen den Segmenten benötigt wird. Bei allen diesen Strategien können Sie innerhalb eines Segments dieselbe Sharding-Strategie beibehalten, die in der Vergangenheit für Sie funktioniert hat.
Wir empfehlen, mit reiner Datenparallelität zu beginnen. Die Verwendung von vollständig shardeten Datenparallelismen ist nützlich, um die Arbeitsspeichernutzung zu verringern. Der Nachteil ist, dass die Kommunikation zwischen den Slices über das DCN-Netzwerk erfolgt und die Arbeitslast verlangsamt. Pipeline-Parallelität nur verwenden, wenn dies basierend auf der Batchgröße erforderlich ist (wie unten analysiert).
Wann sollte Datenparallelität verwendet werden?
Die reine Datenparallelität funktioniert gut, wenn Sie eine Arbeitslast haben, aber Sie möchten die Leistung verbessern, indem Sie in mehreren Segmenten.
Um eine starke Skalierung über mehrere Segmente hinweg zu erzielen, für die vollständige Reduzierung im DCN darf der erforderliche Zeitraum nicht unterschreiten. für einen Rückwärtsdurchlauf. DCN wird für die Kommunikation zwischen Slices und ein begrenzender Faktor für den Arbeitslastdurchsatz.
Jeder TPU v4-Chip erreicht eine Spitzenleistung von 275 * 1012 FLOPs pro Sekunde.
Es gibt vier Chips pro TPU-Host und jeder Host hat eine maximale Netzwerkbandbreite von 50 Gbit/s.
Die arithmetische Intensität beträgt also 4 × 275 × 1012 FLOPS ÷ 50 Gbit/s = 22.000 FLOPS ÷ Bit.
Ihr Modell verwendet für jeden Parameter pro Schritt 32 bis 64 Bit DCN-Bandbreite. Wenn Sie zwei Slices verwenden, nutzt Ihr Modell 32 Bit DCN-Bandbreite. Wenn Sie Wenn Sie mehr als zwei Slices verwenden, führt der Compiler einen vollständigen Shuffle All-Reduce durch. und Sie verwenden bis zu 64 Bit DCN-Bandbreite für jeden Parameter pro Schritt. Wie viele FLOPS für jeden Parameter erforderlich sind, modellieren. Insbesondere bei Transformer-basierten Sprachmodellen ist die Anzahl der FLOPS Erforderlich für eine Vorwärts- und Rückwärtsterminierung sind ungefähr 6 * B * P, wobei Folgendes gilt:
- B ist die Batchgröße in Tokens.
- P ist die Anzahl der Parameter.
Die Anzahl der FLOPS pro Parameter ist 6 * B
und die Anzahl der FLOPS pro Parameter
beim Rückwärtstermin ist 4 * B
.
Um eine starke Skalierung über mehrere Slices hinweg zu gewährleisten, muss der operative
Intensität überschreitet die arithmetische Intensität der TPU-Hardware. Um die Betriebsintensität zu berechnen, teilen Sie die Anzahl der FLOPS pro Parameter während des Rückwärtsdurchlaufs durch die Netzwerkbandbreite (in Bits) pro Parameter pro Schritt:
Operational Intensity = FLOPSbackwards_pass / DCN bandwidth
Wenn Sie also für ein Transformer-basiertes Sprachmodell zwei Chunks verwenden:
Operational intensity = 4 * B / 32
Wenn Sie mehr als zwei Segmente verwenden: Operational intensity = 4 * B/64
Für Transformer-basierte Sprachmodelle wird eine Mindestbatchgröße von 176.000 bis 352.000 empfohlen. Da das DCN Pakete kurzzeitig löschen kann, um eine signifikante Fehlerspanne beizubehalten und nur Datenparallelität bereitzustellen. wenn die Batchgröße pro Pod mindestens 350.000 (zwei Pods) bis 700.000 (viele Pods) beträgt.
Bei anderen Modellarchitekturen müssen Sie die Laufzeit des Rückwärtsdurchlaufs pro Schleife schätzen. Dazu können Sie entweder einen Profiler verwenden oder die FLOPS zählen. Dann können Sie dies mit der erwarteten Laufzeit vergleichen, und eine gute Einschätzung dessen, ob eine Datenparallelität für Sie sinnvoll ist.
Wann sollte die vollständige Shard-Datenparallelität verwendet werden?
Bei der vollständigen Sharding-Datenparallelität (Fully Sharded Data Parallelism, FSDP) werden Datenparallelität (Sharding der Daten auf Knoten) und Gewichts-Sharding auf Knoten kombiniert. Für jeden Vorgang im Vorwärts- und Rückwärtsdurchlauf werden alle Gewichte zusammengetragen, damit jede Scheibe die benötigten Gewichte hat. Anstatt die Gradienten mit All-Reduce zu synchronisieren, werden sie beim Erzeugen mit Reduce-Scatter verteilt. So erhält jeder Sliver nur die Gradienten für die Gewichte, für die er verantwortlich ist.
Ähnlich wie bei der Datenparallelität erfordert FSDP eine Skalierung der globalen Batchgröße linear mit der Anzahl der Segmente. Mit FSDP wird der Speicherdruck verringert, wenn Sie die Anzahl der Scheiben erhöhen. Das liegt daran, dass die Anzahl der Gewichte und der Optimizer-Status pro Slither verringert wird. Dies geht jedoch zu Lasten des Netzwerkverkehrs und erhöht die Wahrscheinlichkeit von Blockierungen aufgrund eines verzögerten Kollektivs.
In der Praxis ist FSDP segmentübergreifend am besten geeignet, wenn Sie den Batch pro mehr Aktivierungen speichern, um die Re-Materialisierung während der die Anzahl der Parameter in Ihrem neuronalen Netzwerk rückwärts durchläuft.
Die Operationen „All-Database“ und „All-Reduce“ bei FSDP funktionieren ähnlich wie bei DP, So können Sie feststellen, ob Ihre FSDP-Arbeitslast durch die DCN-Leistung in wie im vorherigen Abschnitt beschrieben.
Wann sollte Pipelineparallelität verwendet werden?
Der Pipeline-Parallelismus wird relevant, wenn Sie mit anderen Parallelisierungsstrategien eine hohe Leistung erzielen möchten, für die eine globale Batchgröße erforderlich ist, die über Ihrer bevorzugten maximalen Batchgröße liegt. Durch die Pipelineparallelität können die Segmente einer Pipeline einen Batch „teilen“. Die Pipeline-Parallelität hat jedoch wesentliche Nachteile:
- Es kommt zu einer „Pipeline-Blase“, bei der Chips inaktiv sind, weil sie auf Daten warten.
- Es erfordert Mikro-Batching, wodurch sich die effektive Batchgröße verringert, die arithmetische Intensität und schließlich die FLOP-Auslastung.
Die Pipeline-Parallelität sollte nur verwendet werden, wenn die anderen Parallelitätsstrategien eine zu große globale Batchgröße erfordern. Bevor Sie die Pipelineparallelität ausprobieren, sollten Sie empirisch testen, ob sich die Konvergenz pro Stichprobe bei der Batchgröße verlangsamt, die für eine hohe FSDP-Leistung erforderlich ist. Bei der FSDP wird in der Regel eine höhere FLOP-Nutzung des Modells erreicht. Wenn sich die Konvergenz pro Stichprobe jedoch mit zunehmender Batchgröße verlangsamt, ist der Pipeline-Parallelismus möglicherweise die bessere Wahl. Meiste dass Arbeitslasten ausreichend große Batchgrößen tolerieren können, Pipeline-Parallelität. Ihre Arbeitslast kann jedoch abweichen.
Wenn eine Pipeline-Parallelität erforderlich ist, empfehlen wir, sie mit Daten zu kombinieren. oder FSDP. So können Sie die Pipelinetiefe minimieren, während dass die Batchgröße pro Pipeline erhöht wird, bis die DCN-Latenz abnimmt. Durchsatz berücksichtigt wird. Wenn Sie beispielsweise N Slices haben, sollten Sie Pipelines mit einer Tiefe von 2 und N/2 Repliken der Datenparallelität, dann Pipelines mit einer Tiefe von 4 und N/4 Repliken der Datenparallelität usw. in Betracht ziehen, bis der Batch pro Pipeline groß genug ist, dass die DCN-Kollektive hinter der Arithmetik im Rückwärtsdurchlauf verborgen werden können. Dadurch wird die durch den Pipelineparallelismus verursachte Verlangsamung minimiert und Sie können über das globale Limit für die Batchgröße hinaus skalieren.
Best Practices für Mehrfachaufnahmen
Laden der Daten
Während des Trainings werden wiederholt Batches aus einem Dataset geladen, um Daten in den modellieren. Ein effizientes asynchrones Datenladeprogramm, das den Batch auf ist wichtig, damit die TPUs nicht ausgebremst werden. Das aktuelle Datenladeprogramm in MaxText verwendet für jede Hostlast eine gleichwertige Teilmenge der Beispiele. Diese Lösung ist für Text geeignet, erfordert aber eine Neuaufteilung innerhalb des Modells. Zusätzlich kann MaxText bietet noch keine deterministische Snapshot-Erstellung, die es dem Daten-Iterator ermöglichen würde, um vor und nach dem vorzeitigen Beenden die gleichen Daten zu laden.
Prüfpunktausführung
Die Prüfpunktbibliothek Orbax bietet
Primitiven für die Prüfpunktausführung von JAX-PyTrees auf lokalen Speicher oder Google Cloud-Speicher.
Wir bieten eine Referenzintegration mit synchroner Prüfpunktausführung in MaxText.
in „checkpointing.py
“.
Unterstützte Konfigurationen
Formen
Alle Scheiben müssen dieselbe Form haben (z. B. dieselbe AcceleratorType
). Heterogene Scheibenformen werden nicht unterstützt.
Orchestrierung
Die Orchestrierung wird mit GKE unterstützt. Weitere Informationen finden Sie unter TPUs in GKE.
Frameworks
Multislice unterstützt nur JAX- und PyTorch-Arbeitslasten.
Parallelität
Wir empfehlen Nutzern, „Multislice“ mit Datenparallelität zu testen. Weitere Informationen zur Implementierung der Pipeline-Parallelität mit Multislice erfahren Sie von Ihrem Google Cloud-Kundenbetreuer
Support und Feedback
Wir freuen uns über jedes Feedback. Wenn du Feedback geben oder Unterstützung anfordern möchtest, wende dich an uns mithilfe des Support- oder Feedbackformulars für Cloud TPU.