Cloud TPU-Multislice – Übersicht

Cloud TPU Multislice ist eine Technologie zur Leistungsskalierung des Full-Stacks, mit der ein Trainingsjob mehrere TPU-Slices innerhalb eines einzelnen Pods oder auf Slices in mehreren Pods mit einfacher Datenparallelität verwenden kann. Mit TPU v4-Chips bedeutet dies, dass Trainingsjobs mehr als 4.096 Chips in einer einzigen Ausführung verwenden können. Bei Trainingsjobs, die weniger als 4.096 Chips benötigen, kann ein einzelnes Slice die beste Leistung bieten. Mehrere kleinere Segmente sind jedoch leichter verfügbar, was eine schnellere Startzeit ermöglicht, wenn Multislice mit kleineren Segmenten verwendet wird.

Mehrere Slices skalieren die Leistung linear.

Bei der Bereitstellung in Multislice-Konfigurationen kommunizieren TPU-Chips in jedem Segment über Inter-Chip-Interconnect (ICI). TPU-Chips in verschiedenen Teilen kommunizieren durch Übertragung von Daten an CPUs (Hosts), die die Daten wiederum über das Rechenzentrumsnetzwerk (DCN) übertragen.

Multislice-Datenfluss

Entwickler müssen keinen Code schreiben, um die Inter-Slice-DCN-Kommunikation zu implementieren. Der XLA-Compiler generiert diesen Code für Sie und überschneidet sich zur Leistungsoptimierung mit der Kommunikation.

Konzepte

Beschleunigertyp
Die Form jedes TPU-Slice, der aus einem Multislice besteht. Jeder Slice in einer Anfrage mit mehreren Segmenten hat denselben Beschleunigertyp. Ein Beschleunigertyp besteht aus einem TPU-Typ (v4 oder v5e) gefolgt von der Anzahl der TensorCores. Zum Beispiel gibt v4-128 eine TPU v4 mit 128 TensorCores an.
Automatische Reparatur
Wenn bei einem Slice ein Wartungsereignis, eine vorzeitige Beendigung oder ein Hardwarefehler auftritt, erstellt Cloud TPU ein neues Slice. In dem seltenen Fall, dass nicht genügend Ressourcen zum Erstellen eines neuen Slice vorhanden sind, wird die Erstellung erst abgeschlossen, wenn Hardware verfügbar ist. Nachdem das neue Slice erstellt wurde, werden alle anderen Slices in der Multislice-Umgebung neu gestartet, damit das Training fortgesetzt werden kann. Mit einem korrekt konfigurierten Startskript kann das Trainingsskript ohne Nutzereingriff automatisch neu gestartet werden. Es wird vom letzten Checkpoint geladen und fortgesetzt.
Dataset
Die Daten, die von einem Modell für Training oder Inferenz verwendet werden.
Rechenzentrumsnetzwerk (DCN)
Ein Netzwerk mit höherer Latenz und geringerem Durchsatz (im Vergleich zu ICI), das TPU-Slices in einer Multi-Slice-Konfiguration verbindet.
Gruppenplanung
Wenn alle TPU-Slices gleichzeitig bereitgestellt werden, wird garantiert, dass entweder alle oder keine der Slices erfolgreich bereitgestellt werden.
Moderator:in
Ein Host ist ein physischer Computer, auf dem VMs ausgeführt werden. Auf einem Host können maximal vier VMs gleichzeitig ausgeführt werden. Jede VM hat eine dedizierte TPU.
Inferenz
Laden Sie ein vortrainiertes Modell für maschinelles Lernen auf einen Host und treffen Sie Vorhersagen für Daten.
Interchip Interconnect (ICI)
Interne Verbindungen mit hoher Geschwindigkeit und niedriger Latenz, die TPUs in einem TPU Pod verbinden
Mehrfachschnitt
Zwei oder mehr TPU-Chipsegmente, die über DCN kommunizieren können.
Knoten
Im Multislice-Kontext bezieht sich der Knoten auf ein einzelnes TPU-Slice. Jedes TPU-Slice in einem Multislice erhält eine Knoten-ID.
Pod
Eine Sammlung von TPU-Chips, die über dedizierte ICI-Netzwerkschnittstellen verbunden sind. Mit einem Pod können Sie die Verarbeitungslast auf mehrere TPUs verteilen.
Ressource in Warteschlange (QR)
Eine Darstellung von TPU-Ressourcen, die verwendet wird, um eine Anfrage für eine TPU-Umgebung mit einem oder mehreren Segmenten in die Warteschlange zu stellen und zu verwalten.
Startskript
Ein standardmäßiges Compute Engine-Startskript, das jedes Mal ausgeführt wird, wenn eine VM gestartet oder neu gestartet wird. Für Multislice wird er in der Anfrage zur QR-Erstellung angegeben. Weitere Informationen zu Cloud TPU-Startskripts finden Sie unter TPU-Ressourcen verwalten.
TPU-Slice
Ein logischer Unterabschnitt eines TPU-Pod besteht aus TPU-Chips. Alle Chips in einem Slice kommunizieren über das ICI-Netzwerk miteinander.
TPU-VM
Eine virtuelle Maschine, auf der Linux ausgeführt wird und die Zugriff auf die zugrunde liegenden TPUs hat. Bei v4-TPUs hat jede TPU-VM direkten Zugriff auf vier Chips. Manchmal wird eine TPU-VM als Worker bezeichnet.
Google Tensor
Eine Datenstruktur, mit der mehrdimensionale Daten in einem Modell für maschinelles Lernen dargestellt werden.
Tensor Processing Unit (TPU)
Der intern von Google entwickelte Chip für die ML-Beschleunigung. Sie wurden entwickelt, um ein schnelles und energieeffizientes Computing für wichtige ML-Aufgaben wie die Matrixmultiplikation zu ermöglichen.
Arten von Cloud TPU-Kapazitäten

TPUs können aus drei Arten von Kapazitäten erstellt werden (siehe Nutzungsoptionen unter Funktionsweise von TPU-Preisen) :

  • Reservierung: Targeting auf reserviertes Kontingent. Damit Sie reserviertes Kontingent nutzen können, müssen Sie eine Reservierungsvereinbarung mit Google haben. Verwenden Sie beim Erstellen der Ressourcen das Flag --reserved.
  • Auf Abruf: Targeting auf ein auf Abruf verfügbares Kontingent. Ihre Ressourcen können vorzeitig beendet werden, um Platz für Anfragen für einen Job mit höherer Priorität zu schaffen. Verwenden Sie beim Erstellen der Ressourcen das Flag --best-effort.
  • On-Demand: Ist auf ein On-Demand-Kontingent ausgerichtet, das keine Reservierung erfordert und nicht vorzeitig beendet wird. Die TPU-Anfrage wird in eine On-Demand-Kontingentwarteschlange eingereiht, die von Cloud TPU angeboten wird. Die Verfügbarkeit von Ressourcen wird nicht garantiert. Standardmäßig ausgewählt, keine Flags erforderlich.

Mehr erfahren

Wenn Sie noch keine TPUs verwendet haben, installieren Sie zuerst die Google Cloud CLI und richten Sie Ihre Cloud TPU-Umgebung ein. Wenn Sie Multislice verwenden möchten, müssen Ihre TPU-Ressourcen als Ressourcen in der Warteschlange verwaltet werden.

Wenn Sie ein TPU v4-Nutzer sind und eine Reservierung haben, müssen Sie diese möglicherweise zu einem neuen Reservierungssystem migrieren. Weitere Informationen erhalten Sie von Ihrem Google Cloud-Kundenbetreuer.

Einleitendes Beispiel

In dieser Anleitung wird Code aus dem MaxText-GitHub-Repository verwendet. MaxText ist ein leistungsstarkes, beliebig skalierbares Open-Source-basiertes, gut getestetes grundlegendes LLM, das in Python und Jax geschrieben wurde. MaxText wurde für ein effizientes Training in Cloud TPU entwickelt.

Der Code in shardings.py soll Ihnen den Einstieg in das Experimentieren mit verschiedenen Parallelitätsoptionen erleichtern. Zum Beispiel Datenparallelität, vollständig fragmentierte Datenparallelität (FSDP) und Tensorparallelität. Der Code skaliert von Umgebungen mit einem einzelnen Segment bis zu Umgebungen mit mehreren Teilen.

ICI-Parallelität

ICI ist die Hochgeschwindigkeits-Interconnect-Verbindung, die die TPUs in einem einzelnen Slice verbindet. Die ICI-Fragmentierung entspricht der Fragmentierung innerhalb eines Slice. shardings.py stellt drei ICI-Parallelitätsparameter bereit:

  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

Die Werte, die Sie für diese Parameter angeben, bestimmen die Anzahl der Shards für jede Parallelitätsmethode.

Diese Eingaben müssen begrenzt werden, damit ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism der Anzahl der Chips im Slice entspricht.

Die folgende Tabelle zeigt Beispielnutzereingaben für die ICI-Parallelität für die vier in v4-8 verfügbaren Chips:

ici_data_parallelism ici_fsdp_parallelism ici_tensor_parallelism
4-Wege-FSDP 1 4 1
4-Wege-Tensor-Parallelität 1 1 4
2-Wege-FSDP + 2-Wege-Tensor-Parallelität 1 2 2

Beachten Sie, dass ici_data_parallelism in den meisten Fällen auf 1 bleiben sollte, da das ICI-Netzwerk schnell genug ist, um FSDP der Datenparallelität fast immer vorzuziehen.

Bei diesem Beispiel wird davon ausgegangen, dass Sie mit dem Ausführen von Code auf einem einzelnen TPU-Slice vertraut sind, wie z. B. im Artikel Berechnung auf einer Cloud TPU-VM mit JAX ausführen beschrieben. In diesem Beispiel wird gezeigt, wie shardings.py für ein einzelnes Slice ausgeführt wird.

  1. Richten Sie die Umgebung ein:

    $ gcloud auth login
    $ gcloud config set project your-project-id
    $ gcloud config set compute/zone your-zone
    
  2. Erstellen Sie SSH-Schlüssel für gcloud. Wir empfehlen, ein leeres Passwort zu lassen. Drücken Sie dazu zweimal die Eingabetaste, nachdem Sie den folgenden Befehl ausgeführt haben. Wenn Sie gefragt werden, ob die Datei google_compute_engine bereits vorhanden ist, ersetzen Sie die vorhandene Version.

    $ ssh-keygen -f ~/.ssh/google_compute_engine
    
  3. Stellen Sie Ihre TPUs mit dem folgenden Befehl bereit:

    $ gcloud alpha compute tpus queued-resources \
    create your-qr-id \
    --accelerator-type your-accelerator-type \
    --runtime-version tpu-ubuntu2204-base \
    --node-id qr-id \
    [--reserved |--best-effort]
    

    Beschreibung der Befehls-Flags

    your-qr-id
    Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    runtime-version
    Die [Cloud TPU-Softwareversion](/tpu/docs/supported-tpu-configurations#tpu_software_versions).
    node-id
    Die ID der TPU-Ressourcen, die als Antwort auf die QR-Anfrage erstellt werden.
    reserved
    Beim Erstellen der Slices ein reserviertes Kontingent verwenden
    best-effort
    Beim Erstellen der Segmente Best-Effort-Kontingent verwenden [Standardeinstellung]

    Die Google Cloud CLI unterstützt nicht alle Optionen zum Erstellen von QR-Codes, z. B. Tags. Weitere Informationen finden Sie unter QRs erstellen.

  4. Warten Sie, bis der QR-Code den Status ACTIVE hat, was bedeutet, dass sich die Worker-Knoten im Status READY befinden. Sobald die Bereitstellung des QR-Codes gestartet wurde, kann es je nach Größe des QR-Codes ein bis fünf Minuten dauern, bis die Bereitstellung abgeschlossen ist. Sie können den Status einer QR-Anfrage mit dem folgenden Befehl prüfen:

    $ gcloud alpha compute tpus queued-resources \
      list --filter=your-qr-id
    
  5. Ein v4-8-Slice hat eine einzelne TPU-VM. Stellen Sie über SSH eine Verbindung zur TPU-VM her:

    $ gcloud compute tpus tpu-vm ssh your-qr-id
    
  6. Klonen Sie MaxText (einschließlich shardings.py) auf Ihrer TPU-VM.

  7. Führen Sie im MaxText-Repository-Verzeichnis das Einrichtungsskript aus, um JavaX und andere Abhängigkeiten auf Ihrem TPU-Slice zu installieren. Das Ausführen des Setup-Skripts dauert einige Minuten.

    $ bash setup.sh
    
  8. Führen Sie den folgenden Befehl aus, um shardings.py für Ihr TPU-Slice auszuführen.

    $ python3 pedagogical_examples/shardings.py \
      --ici_fsdp_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    

    Die Ergebnisse finden Sie in den Logs. Ihre TPUs sollten etwa 260 TFLOP pro Sekunde oder eine beeindruckende FLOP-Auslastung von mehr als 90%erreichen. In diesem Fall haben wir ungefähr den maximalen Batch ausgewählt, der in den High Bandwidth Memory (HBM) der TPU passt.

  9. Sie können auch über ICI andere Fragmentierungsstrategien ausprobieren, beispielsweise mit der folgenden Kombination:

    $ python3 pedagogical_examples/shardings.py \
      --ici_tensor_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    
  10. Löschen Sie anschließend das QR- und das TPU-Slice. Sie sollten diese Bereinigungsschritte in der Umgebung ausführen, in der Sie das Slice eingerichtet haben. Führen Sie zuerst exit aus, um die SSH-Sitzung zu beenden. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

    $ gcloud alpha compute tpus queued-resources
      delete your-qr-id --force (--async)
    

Mehrfachfragmentierung mit DCN-Parallelität

Das Skript shardings.py verwendet drei Parameter, die die DCN-Parallelität angeben, die der Anzahl der Shards für jede Art von Datenparallelität entsprechen:

  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism

Die Werte dieser Parameter müssen so begrenzt sein, dass dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism der Anzahl der Sektoren entspricht.

Verwenden Sie als Beispiel für zwei Segmente --dcn_data_parallelism = 2.

dcn_data_parallelism dcn_fsdp_parallelism dcn_tensor_parallelism Anzahl der Sektoren
Zweiseitige Datenparallelität 2 1 1 2

dcn_tensor_parallelism sollte immer auf 1 gesetzt werden, da das DCN nicht für eine solche Fragmentierung geeignet ist. Für typische LLM-Arbeitslasten auf v4-Chips sollte dcn_fsdp_parallelism ebenfalls auf 1 und daher dcn_data_parallelism auf die Anzahl der Slices festgelegt werden. Dies ist jedoch anwendungsabhängig.

Wenn Sie die Anzahl der Slices erhöhen (vorausgesetzt, Sie halten die Slice-Größe und den Batch pro Slice konstant), erhöhen Sie die Menge der Datenparallelität.

shardings.py in einer Umgebung mit mehreren Teilen ausführen

Sie können shardings.py in einer Umgebung mit mehreren Schichten mithilfe von multihost_runner.py oder durch Ausführen von shardings.py auf jeder TPU-VM ausführen. Hier verwenden wir multihost_runner.py. Die folgenden Schritte sind denen unter Erste Schritte: Schnelle Tests für mehrere Segmente aus dem MaxText-Repository sehr ähnlich, mit der Ausnahme, dass hier shardings.py anstelle des komplexeren LLM in train.py ausgeführt wird.

Das multihost_runner.py-Tool ist für schnelle Tests optimiert, bei denen wiederholt dieselben TPUs verwendet werden. Da das Skript multihost_runner.py von langlebigen SSH-Verbindungen abhängt, empfehlen wir es nicht für Jobs mit langer Ausführungszeit. Wenn Sie einen längeren Job ausführen möchten (z. B. Stunden oder Tage), empfehlen wir die Verwendung von multihost_job.py.

In dieser Anleitung verwenden wir den Begriff runner, um die Maschine anzugeben, auf der Sie das Skript multihost_runner.py ausführen. Wir verwenden den Begriff Worker, um die TPU-VMs anzugeben, aus denen Ihre Slices bestehen. Sie können multihost_runner.py auf einem lokalen Computer oder einer beliebigen Compute Engine-VM im selben Projekt wie die Segmente ausführen. Das Ausführen von multihost_runner.py auf einem Worker wird nicht unterstützt.

multihost_runner.py stellt automatisch über SSH eine Verbindung zu TPU-Workern her.

In diesem Beispiel wird shardings.py über zwei v4-16-Segmente ausgeführt, also insgesamt vier VMs und 16 TPU-Chips. Sie können das Beispiel so ändern, dass es auf mehr TPUs ausgeführt wird.

Umgebung einrichten

  1. Klonen Sie MaxText auf dem Läufercomputer.

  2. Rufen Sie das Repository-Verzeichnis auf.

  3. Erstellen Sie SSH-Schlüssel für gcloud. Wir empfehlen, ein leeres Passwort zu lassen (drücken Sie die Eingabetaste zweimal, nachdem Sie den folgenden Befehl ausgeführt haben). Wenn Sie gefragt werden, ob die Datei google_compute_engine bereits vorhanden ist, wählen Sie aus, dass die vorhandene Version nicht beibehalten werden soll.

      $ ssh-keygen -f ~/.ssh/google_compute_engine
      

  4. Fügen Sie eine Umgebungsvariable hinzu, um die Anzahl der TPU-Slices auf 2 festzulegen.

      $ export SLICE_COUNT=2
      

  5. Erstellen Sie mit queued-resources create eine Umgebung mit mehreren Scheiben.

    Der folgende Befehl zeigt, wie eine v4-Multislice-TPU erstellt wird. Wenn Sie v5e verwenden möchten, geben Sie accelerator-type für V5e an (z. B. v5litepod-16) und für V5e-runtime-version (v2-alpha-tpuv5-lite).

      $ gcloud alpha compute tpus queued-resources 
    create your-qr-id
    --accelerator-type=your-accelerator-type
    --runtime-version=tpu-vm-runtime-version
    --node-count=node-count
    --node-prefix=your-qr-id
    [--reserved|--best-effort]

    Beschreibung der Befehls-Flags

    your-qr-id
    Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    runtime-version
    Die Cloud TPU-Softwareversion.
    node-count
    Die Anzahl der zu erstellenden Segmente.
    node-prefix
    Das Präfix, mit dem Namen für jedes Segment generiert werden. Für jedes Segment wird eine Zahl an das Präfix angehängt. Wenn Sie beispielsweise node-prefix auf mySlice setzen, werden die Segmente benannt: mySlice-0, mySlice-1 usw.
    reserved
    Beim Erstellen der Slices ein reserviertes Kontingent verwenden
    best-effort
    Beim Erstellen der Segmente Best-Effort-Kontingent verwenden [Standardeinstellung]

  6. Wenn die Bereitstellung des QR-Codes gestartet wird, kann es je nach Größe des QR-Codes bis zu 5 Minuten dauern, bis er abgeschlossen ist. Warten Sie, bis die Ressource in der Warteschlange (QR) den Status ACTIVE hat. Sie können den Status einer QR-Anfrage mit dem folgenden Befehl prüfen:

    $ gcloud alpha compute tpus queued-resources list \
    --filter=your-qr-id
    

    Die Ausgabe sollte in etwa so aussehen:

    NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
    ...
    que-res-id  us-central2-b  4           v4-16             ACTIVE
    ...
    

    Wenden Sie sich an Ihren Google Cloud-Kundenbetreuer, wenn der QR-Status länger als 15 Minuten den Status WAITING_FOR_RESOURCES oder PROVISIONING hat.

  7. Installieren Sie die Abhängigkeiten:

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="bash setup.sh"
    
  8. Führen Sie mit multihost_runner.py auf jedem Worker shardings.py aus.

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="python3 pedagogical_examples/shardings.py \
      --dcn_data_parallelism $SLICE_COUNT \
      --ici_fsdp_parallelism 8 \
      --batch_size 131072 \
      --embedding_dimension 2048"
    

    In den Logdateien sehen Sie eine Leistung von etwa 230 TFLOPs pro Sekunde.

  9. Wenn Sie fertig sind, bereinigen Sie die TPUs und den QR-Code. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

Arbeitslast in mehrere Segmente skalieren

Bevor Sie das Modell in einer Umgebung mit mehreren Teilen ausführen, nehmen Sie die folgenden Codeänderungen vor:

Dies sollte die einzigen erforderlichen Codeänderungen sein, wenn Sie zu „Mehrfachsegmente“ wechseln. Um eine hohe Leistung zu erzielen, muss DCN parallelen, vollständig fragmentierten Daten parallelen oder Pipeline-Parallelachsen zugeordnet werden. Leistungsaspekte und Fragmentierungsstrategien werden im Abschnitt Fragmentierung mit mehreren Scheiben für maximale Leistung ausführlicher erläutert.

Wenn Sie prüfen möchten, ob Ihr Code auf alle Geräte zugreifen kann, können Sie bestätigen, dass len(jax.devices()) der Anzahl der Chips in Ihrer Multislice-Umgebung entspricht. Wenn Sie beispielsweise vier v4-16-Segmente verwenden, haben Sie acht Chips pro Slice * 4 Slices, sodass len(jax.devices()) 32 zurückgeben sollte.

Segmentgrößen für Umgebungen mit mehreren Scheiben auswählen

Für eine lineare Beschleunigung fügen Sie neue Segmente der Größe des vorhandenen Segments hinzu. Wenn Sie beispielsweise ein v4-512-Slice verwenden, wird mit dem Multislice eine etwa doppelt so hohe Leistung erzielt. Dazu wird ein zweites v4-512-Slice hinzugefügt und die globale Batchgröße verdoppelt. Weitere Informationen finden Sie unter Fragmentierung mit mehreren Scheiben für maximale Leistung.

Job in mehreren Segmenten ausführen

Es gibt drei verschiedene Ansätze zum Ausführen Ihrer benutzerdefinierten Arbeitslast in einer Umgebung mit mehreren Schneiden:

  1. Es wird das Skript für den Testlauf verwendet: multihost_runner.py
  2. Mit dem Script multihost_job.py für die Produktionsausführung
  3. Manueller Ansatz

Skript für den Testausführer

Das Skript multihost_runner.py verteilt Code auf eine vorhandene Umgebung mit mehreren Scheiben und führt den Befehl auf jedem Host aus, kopiert Ihre Logs zurück und verfolgt den Fehlerstatus jedes Befehls. Das Skript multihost_runner.py ist in der Readme-Datei für MaxText dokumentiert.

Da multihost_runner.py persistente SSH-Verbindungen beibehält, ist es nur für Tests mit mittlerer Größe und relativ kurzer Laufzeit geeignet. Sie können die Schritte in der Anleitung multihost_runner.py an Ihre Arbeitslast und Hardwarekonfiguration anpassen.

Skript für Produktionsausführer

Produktionsjobs, die Resilienz gegen Hardwarefehler und andere vorzeitige Beendigungen erfordern, sollten direkt in die Create Queued Resource API eingebunden werden. Als funktionierendes Beispiel stellen wir multihost_job.py bereit, das den Aufruf der Created Queued Resource API mit dem entsprechenden Startskript auslöst, um das Training auszuführen und beim vorzeitigen Beenden fortzufahren. Das Skript multihost_job.py ist in der Readme-Datei für MaxText dokumentiert.

Da multihost_job.py für jede Ausführung Ressourcen bereitstellen muss, bietet es keinen so schnellen Iterationszyklus wie multihost_runner.py.

Manueller Ansatz

Wir empfehlen, multihost_runner.py oder multihost_job.py zu verwenden oder anzupassen, um Ihre benutzerdefinierte Arbeitslast in Ihrer Multislice-Konfiguration auszuführen. Wenn Sie Ihre Umgebung jedoch lieber direkt mit QR-Befehlen bereitstellen und verwalten möchten, lesen Sie Umgebungen mit mehreren Segmenten verwalten.

Umgebung mit mehreren Segmenten verwalten

In den folgenden Abschnitten erfahren Sie, wie Sie QR-Codes manuell bereitstellen und verwalten können, ohne die Tools im MaxText-Repository zu verwenden.

Kurzantworten erstellen

Legen Sie die folgenden Umgebungsvariablen fest, bevor Sie Kapazität bereitstellen:

  $ export your-qr-id=your-queued-resource-id
  $ export PROJECT=your-project-name
  $ export ZONE=us-central2-b
  $ export NETWORK_NAME=your-network-name
  $ export SUBNETWORK_NAME=your-subnetwork-name
  $ export RUNTIME_VERSION=tpu-ubuntu2204-base
  $ export ACCELERATOR_TYPE=v4-16
  $ export SLICE_COUNT=4
  $ export STARTUP_SCRIPT="#!/bin/bash\n ..."
  $ gcloud config set project project-name
  $ gcloud config set compute/zone zone
Eingabe Beschreibung
your-qr-id Die vom Nutzer zugewiesene ID des QR-Codes.
PROJEKT Name des Google Cloud-Projekts
ZONE us-central2-b
NETWORK_NAME Name der VPC-Netzwerke.
SUBNETWORK_NAME Name des Subnetzes in VPC-Netzwerken
RUNTIME_VERSION TPU-Ubuntu2204-Base
ACCELERATOR_TYPE v4-16
EXAMPLE_TAG_1, EXAMPLE_TAG_2 ... Tags zur Identifizierung gültiger Quellen oder Ziele für Netzwerk-Firewalls
SLICE_COUNT Anzahl der Sektoren. Beschränkt auf maximal 256 Slices.
STARTUP_SCRIPT Wenn es der Erstellungsanfrage hinzugefügt wird, kann ein Startskript ausgeführt werden, wenn ein TPU-Slice bereitgestellt oder neu gestartet wird und das TPU-Slice repariert oder zurückgesetzt wird.

QR-Anfrage mit gcloud erstellen

$ gcloud alpha compute tpus queued-resources \
  create ${your-qr-id} \
  --project your-project-id \
  --zone your-zone \
  --node-count ${SLICE_COUNT} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --network ${NETWORK_NAME} \
  --subnetwork ${SUBNETWORK_NAME} \
  --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}'
  [--reserved|--best-effort]
  

Beschreibung der Befehls-Flags

your-qr-id
Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
project
Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
zone
Die Google Cloud-Zone, in der der QR-Code erstellt werden soll.
node-count
Die Anzahl der zu erstellenden Segmente.
accelerator-type
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
runtime-version
Die Cloud TPU-Softwareversion.
network
Der Name eines VPC-Netzwerk, an das die TPU-Ressource angehängt werden soll.
subnetwork
Der Name eines VPC-Subnetzwerks, an das die TPU-Ressource angehängt werden soll.
reserved
Beim Erstellen der Slices ein reserviertes Kontingent verwenden
best-effort
Beim Erstellen der Segmente Best-Effort-Kontingent verwenden [Standardeinstellung]

Prüfen Sie, ob Sie das entsprechende Kontingent haben, bevor Sie --reserved, --best_effort oder das standardmäßige On-Demand-Kontingent auswählen. Informationen zu Kontingenttypen finden Sie in der Kontingentrichtlinie.

QR-Anfrage mit curl erstellen

Erstellen Sie eine Datei mit dem Namen queued-resource-req.json und kopieren Sie die folgende JSON-Datei in diese Datei.

{
  "guaranteed": { "reserved": true },
  "tpu": {
    "node_spec": [
    {
      "parent": "projects/your-project-number/locations/your-zone",
        "node": {
          "accelerator_type": "accelerator-type",
          "runtime_version": "tpu-vm-runtime-version",
          "network_config": {
            "network": "your-network-name",
            "subnetwork": "your-subnetwork-name",
            "enable_external_ips": true
          },
          "tags" : ["example-tag-1"]
          "metadata": {
            "startup-script": "your-startup-script"
          }
      },
      "multi_node_params": {
        "node_count": slice-count,
        "node_id_prefix": "your-queued-resource-id"
      }
    }
    ]
  }
}
  • your-project-number – Ihre Google Cloud-Projektnummer
  • your-zone – Der Bereich, in dem Sie den QR-Code erstellen möchten
  • accelerator-type: Version und Größe eines einzelnen Segments
  • tpu-vm-runtime-version – Die TPU-VM-Laufzeitversionen
  • your-network-name: Optional: ein Netzwerk, mit dem der QR-Code verbunden wird
  • your-subnetwork-name: ein Subnetzwerk, an das der QR-Code angehängt wird (optional)
  • example-tag-1 – ein beliebiger Tag-String (optional)
  • your-startup-script – Ein Startskript, das bei der Zuweisung des QR-Codes ausgeführt wird
  • slice-count – die Anzahl der TPU-Slices in Ihrer Multislice-Umgebung
  • your-qr-id: die vom Nutzer angegebene ID für den QR-Code

Weitere Informationen zu allen verfügbaren Optionen finden Sie in der Dokumentation zur REST Queued Resource API.

Ersetzen Sie Folgendes, um Kapazität auf Abruf zu verwenden:

"guaranteed": { "reserved": true } mit "best_effort": {}

Alternativ können Sie die Linie entfernen, um die Standard-On-Demand-Kapazität zu verwenden.

Senden Sie die Anfrage zum Erstellen von QR-Codes mit der JSON-Nutzlast:

  $ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" -d @queuedresourcereq.json https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
  • your-project-id – Ihre Google Cloud-Projekt-ID
  • your-zone – Der Bereich, in dem Sie den QR-Code erstellen möchten
  • your-qr-id: die vom Nutzer angegebene ID für den QR-Code

Die Antwort sollte in etwa so aussehen:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
  },
  "done": false
}

Verwenden Sie den GUID-Wert am Ende des Stringwerts für das Attribut name, um Informationen zur QR-Anfrage abzurufen.

Status eines QR-Codes abrufen

Verwenden Sie den folgenden Befehl, um den Status der QR-Anfrage abzurufen:

  $ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
  • your-project-id – Ihre Google Cloud-Projekt-ID
  • your-zone – Die Zone, in der der QR-Code erstellt werden soll.
  • your-qr-guid: Die GUID, die auf name in der Ausgabe der Anfrage zur QR-Erstellung folgt.

Die Antwort auf diesen Befehl enthält den Status des Vorgangs:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qa-guid>,
  "metadata": {...},
  "done": true,
  "response": {
    "@type": "type.googleapis.com/google.cloud.tpu.v2alpha1.QueuedResource",
    ...
    "state": {
      "state": "WAITING_FOR_RESOURCES"
    }
  }
}

Wenn der QR-Code erfolgreich ("done = true") erstellt wurde, ist der Status im Feld response entweder WAITING_FOR_RESOURCES oder FAILED. Wenn der QR-Code den Status WAITING_FOR_RESOURCES hat, wurde der QR-Code in die Warteschlange gestellt und mit der Bereitstellung begonnen, sobald genügend Ressourcen vorhanden sind. Wenn der QR-Code den Status FAILED hat, wird der Fehlergrund in der Ausgabe angegeben. Weitere Informationen zu anderen möglichen Status finden Sie im Nutzerhandbuch für Ressourcen in der Warteschlange.

Sobald der Vorgang abgeschlossen ist, können Sie die Phasen des QR-Codes mit Beschreibung von QRs überwachen.

In einem seltenen Szenario kann der QR-Code den Status FAILED haben, während einige Slices den Status ACTIVE haben. Löschen Sie in diesem Fall die erstellten Ressourcen und versuchen Sie es in einigen Minuten noch einmal oder wenden Sie sich an das Cloud TPU-Team, um das Problem zu beheben.

SSH und Abhängigkeiten installieren

Unter JAX-Code auf TPU-Pod-Slices ausführen wird gezeigt, wie Sie mithilfe von SSH eine Verbindung zu Ihren TPU-VMs in einem einzelnen Segment herstellen. Verwenden Sie den folgenden gcloud-Befehl, um über SSH eine Verbindung zu allen TPU-VMs in Ihrer Multislice-Umgebung herzustellen und Abhängigkeiten zu installieren:

  $ gcloud compute tpus queued-resources ssh ${your-qr-id} \
    --zone your-zone \
    --node=all \
    --worker=all \
    --command="command-to-run"
    --batch-size=4

Mit diesem gcloud-Befehl wird der angegebene Befehl über SSH an alle Worker und Knoten in QR-Code gesendet. Der Befehl wird in Vierergruppen zusammengefasst und gleichzeitig gesendet. Der nächste Batch von Befehlen wird gesendet, wenn der aktuelle Batch die Ausführung abgeschlossen hat. Wenn bei einem der Befehle ein Fehler auftritt, wird die Verarbeitung beendet und es werden keine weiteren Batches gesendet. Weitere Informationen finden Sie in der API-Referenz für Ressourcen in Warteschlange. Wenn die Anzahl der verwendeten Slices das Threading-Limit Ihres lokalen Computers überschreitet (auch Batching-Limit genannt), kommt es zu einem Deadlock. Angenommen, das Batch-Limit auf Ihrem lokalen Computer beträgt 64. Wenn Sie versuchen, ein Trainingsskript für mehr als 64 Slices, z. B. 100, auszuführen, teilt der SSH-Befehl die Slices in Batches auf. Er führt das Trainingsskript für den ersten Batch von 64 Slices aus und wartet, bis die Skripts abgeschlossen sind, bevor er das Skript für den restlichen Batch von 36 Slices ausführt. Der erste Batch von 64 Segmenten kann jedoch erst abgeschlossen werden, wenn die verbleibenden 36 Slices mit der Ausführung des Skripts beginnen. Dies führt zu einem Deadlock.

Um dieses Szenario zu verhindern, können Sie das Trainingsskript auf jeder VM im Hintergrund ausführen. Dazu hängen Sie ein Et-Zeichen (&) an den Skriptbefehl an, den Sie mit dem Flag --command angeben. In diesem Fall kehrt die Steuerung sofort zum SSH-Befehl zurück, nachdem das Trainingsskript für den ersten Batch von Segmenten gestartet wurde. Der SSH-Befehl kann dann das Trainingsskript für den verbleibenden Batch von 36 Slices ausführen. Wenn Sie die Befehle im Hintergrund ausführen, müssen Sie Ihre stdout- und stderr-Streams entsprechend per Pipe verwenden. Um die Parallelität innerhalb desselben QR-Codes zu erhöhen, können Sie mit dem Parameter --node bestimmte Slices auswählen.

Netzwerkeinrichtung

Führen Sie die folgenden Schritte aus, damit TPU-Slices miteinander kommunizieren können. Installieren Sie JAX für jedes Segment. Weitere Informationen finden Sie unter JAX-Code auf TPU-Pod-Slices ausführen. Bestätigen Sie, dass len(jax.devices()) der Anzahl der Chips in Ihrer Multislice-Umgebung entspricht. Führen Sie dazu für jedes Slice folgenden Befehl aus:

  $ python3 -c 'import jax; print(jax.devices())'

Wenn Sie diesen Code in vier Segmenten von v4-16 ausführen, gibt es acht Chips pro Slice und vier Segmente, also sollten insgesamt 32 Chips (Geräte) von jax.devices() zurückgegeben werden.

Kurzantworten auflisten

Sie können den Status Ihrer QRs mit dem Befehl queued-resources list aufrufen:

$ gcloud alpha compute tpus queued-resources list

NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
...
que-res-id  us-central2-b  4           v4-16             ACTIVE
...

Kurzantworten beschreiben

Um die detaillierte Konfiguration und den Status eines QR-Codes anzusehen, verwenden Sie die Beschreibung der QR API. Sie können diese API mit gcloud oder curl aufrufen.

mit gcloud:

$ gcloud alpha compute tpus queued-resources describe ${your-qr-id}
...state:
 state: ACTIVE
...

mit curl:

$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources/${your-qr-id}
{
  "name": your-queued-res,
  "tpu": {
    "nodeSpec": [
      {
        ... // node 1
      },
      {
        ... // node 2
      },
      ...
    ]
  },
  ...
  "state": "ACTIVE"
}

state steht für den Status eines QR-Codes. Weitere Informationen zu den möglichen Status von QRs finden Sie unter Ressourcen in der Warteschlange.

Job in einer bereitgestellten Umgebung starten

Sie können Arbeitslasten manuell ausführen. Stellen Sie dazu über SSH eine Verbindung zu allen Hosts in jedem Segment her und führen Sie den folgenden Befehl auf allen Hosts aus.

$ gcloud compute tpus tpu-vm ssh your-qr-id \
  --zone=your-zone \
  --worker=all \
  --node=all \
  --command="command-to-run"

Kurzantworten werden zurückgesetzt

Mit der ResetQueuedResource API können alle VMs mit einem ACTIVE-QR-Code zurückgesetzt werden. Durch das Zurücksetzen der VMs wird das Löschen des Arbeitsspeichers der Maschine erzwungen und die VM in ihren Ausgangszustand zurückgesetzt. Alle lokal gespeicherten Daten bleiben intakt und das Startskript wird nach einem Zurücksetzen aufgerufen. Die ResetQueuedResource API kann nützlich sein, wenn Sie alle TPUs neu starten möchten. Wenn beispielsweise das Training hängt und das Zurücksetzen aller VMs einfacher ist als das Debugging.

Alle VMs werden parallel zurückgesetzt. Ein ResetQueuedResource-Vorgang dauert ein bis zwei Minuten. Verwenden Sie den folgenden Befehl, um die API aufzurufen:

$ gcloud alpha compute tpus queued-resources reset your-qr-id

Kurzantworten werden gelöscht

Zum Freigeben von Ressourcen am Ende der Trainingssitzung löschen Sie die Ressource in der Warteschlange mit dem Flag --force. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

$ gcloud alpha compute tpus queued-resources \
delete your-qr-id --force (--async)

Automatische Wiederherstellung nach Fehlern

Im Falle einer Störung bietet Multislice eine interventionsfreie Reparatur des betroffenen Slice und Zurücksetzen aller Slices anschließend an. Das betroffene Slice wird durch ein neues ersetzt und die ansonsten fehlerfreien Slices werden zurückgesetzt. Wenn keine Kapazität zum Zuweisen eines Ersatzsegments verfügbar ist, wird das Training beendet.

Damit das Training nach einer Unterbrechung automatisch fortgesetzt wird, müssen Sie ein Startskript angeben, das die zuletzt gespeicherten Prüfpunkte prüft und lädt. Das Startskript wird automatisch ausgeführt, wenn ein Slice neu zugewiesen oder eine VM zurückgesetzt wird. Sie geben ein Startskript in der JSON-Nutzlast an, die Sie an die API zum Erstellen von QR-Anfragen senden.

Mit dem folgenden Startskript, das in QRs erstellen verwendet wird, können Sie automatisch nach Fehlern wiederherstellen und das Training an Prüfpunkten fortsetzen, die während des MaxText-Trainings in einem Cloud Storage-Bucket gespeichert wurden:

{
 "tpu": {
   "node_spec": [
     {
      ...
         "metadata": {
               "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
         }
     ...
     }
   ]
 }
}

Klonen Sie das MaxText-Repository, bevor Sie dies ausprobieren.

Profilerstellung und Fehlerbehebung

Die Profilerstellung ist in Umgebungen mit einem einzelnen Segment und in Umgebungen mit mehreren Segmenten identisch. Weitere Informationen finden Sie unter Profilerstellung für JAX-Programme.

Optimierte Schulungen

Fragmentierung mit mehreren Segmenten für maximale Leistung

Um die maximale Leistung in Umgebungen mit mehreren Segmenten zu erreichen, muss berücksichtigt werden, wie die Fragmentierung auf die verschiedenen Segmente vorgenommen werden soll. In der Regel gibt es drei Möglichkeiten (Datenparallelität, vollständig fragmentierte Datenparallelität und Pipelineparallelität). Wir raten davon ab, Aktivierungen über die Modelldimensionen hinweg zu fragmentieren (manchmal als Tensor-Parallelität bezeichnet), da dafür zu viel Bandbreite zwischen den einzelnen Segmenten benötigt wird. Bei allen diesen Strategien können Sie dieselbe Fragmentierungsstrategie innerhalb eines Slice beibehalten, der sich in der Vergangenheit bewährt hat.

Wir empfehlen, mit der reinen Datenparallelität zu beginnen. Die Verwendung der vollständig fragmentierten Datenparallelität ist nützlich, um Arbeitsspeichernutzung freizugeben. Der Nachteil besteht darin, dass die Kommunikation zwischen den Segmenten das DCN-Netzwerk verwendet und Ihre Arbeitslast verlangsamt. Verwenden Sie die Pipelineparallelität nur, wenn dies basierend auf der Batchgröße erforderlich ist (wie unten dargestellt).

Wann sollte Datenparallelität verwendet werden?

Reine Datenparallelität funktioniert gut in Fällen, in denen eine Arbeitslast gut ausgeführt wird, Sie aber ihre Leistung durch eine Skalierung über mehrere Slices verbessern möchten.

Um eine starke Skalierung über mehrere Segmente hinweg zu erreichen, muss die für die vollständige Reduzierung über das DCN erforderliche Zeit kleiner sein als die für einen Rückwärtsdurchlauf erforderliche Zeit. DCN wird für die Kommunikation zwischen Segmenten verwendet und ist ein begrenzender Faktor für den Arbeitslastdurchsatz.

Jeder v4-TPU-Chip bietet eine Spitzenleistung von 275 × 1012 FLOPS pro Sekunde.

Pro TPU-Host gibt es vier Chips, und jeder Host hat eine maximale Netzwerkbandbreite von 50 Gbit/s.

Das bedeutet, dass die arithmetische Intensität 4 * 275 * 1012 FLOPS / 50 Gbit/s = 22.000 FLOPS/Bit beträgt.

Das Modell verwendet eine DCN-Bandbreite von 32 bis 64 Bit für jeden Parameter und Schritt. Wenn Sie zwei Slices verwenden, nutzt Ihr Modell eine DCN-Bandbreite von 32 Bit. Wenn Sie mehr als zwei Slices verwenden, führt der Compiler einen vollständigen Shuffle-Vorgang zur kompletten Reduzierung aus und Sie nutzen für jeden Parameter und Schritt bis zu 64 Bit DCN-Bandbreite. Die für jeden Parameter erforderliche Anzahl von FLOPS variiert je nach Modell. Insbesondere für Transformer-basierte Sprachmodelle beträgt die Anzahl der für einen Vorwärts- und Rückwärtstermin erforderlichen FLOPS ungefähr 6 * B * P, wobei Folgendes gilt:

  • B ist die Batchgröße in Tokens.
  • P ist die Anzahl der Parameter.

Die Anzahl der FLOPS pro Parameter beträgt 6 * B und die Anzahl der FLOPS pro Parameter während des Rückwärtsdurchlaufs ist 4 * B.

Achten Sie darauf, dass die Betriebsintensität die arithmetische Intensität der TPU-Hardware überschreitet, um eine starke Skalierung über mehrere Segmente hinweg zu gewährleisten. Sie können die Betriebsintensität berechnen, indem Sie die Anzahl der FLOPS pro Parameter während der Rückwärtsdurchführung durch die Netzwerkbandbreite (in Bit) pro Parameter und Schritt teilen: Operational Intensity = FLOPSbackwards_pass / DCN bandwidth

Wenn Sie daher bei einem Transformer-basierten Sprachmodell zwei Slices verwenden: Operational intensity = 4 * B / 32

Wenn Sie mehr als zwei Segmente verwenden: Operational intensity = 4 * B/64

Dies empfiehlt eine Mindest-Batchgröße zwischen 176.000 und 352.000 für Transformer-basierte Sprachmodelle. Da das DCN-Netzwerk kurzzeitig Pakete verwerfen kann, ist es am besten, einen erheblichen Fehlerbereich aufrechtzuerhalten. Die Datenparallelität sollte nur bereitgestellt werden, wenn die Batchgröße pro Pod mindestens 350.000 (zwei Pods) bis 700.000 (viele Pods) beträgt.

Bei anderen Modellarchitekturen müssen Sie die Laufzeit Ihrer Rückwärtsläufe pro Slice schätzen (entweder durch Timing mit einem Profiler oder durch Zählen von FLOPS). Anschließend können Sie dies mit der erwarteten Laufzeit vergleichen, um alle über DCN zu reduzieren, und eine gute Schätzung dazu erhalten, ob die Datenparallelität für Sie sinnvoll ist.

Wann sollte die vollständig fragmentierte Datenparallelität (FSDP) verwendet werden?

Die vollständig fragmentierte Datenparallelität (FSDP) kombiniert Datenparallelität (Fragmentierung der Daten über Knoten) mit der Fragmentierung der Gewichtungen über Knoten. Für jede Operation in den Vorwärts- und Rückwärtsdurchläufen werden die Gewichtungen gesammelt, sodass jeder Slice die benötigten Gewichtungen hat. Statt die Gradienten mithilfe von „all-Reduce“ zu synchronisieren, werden die Gradienten „reduziert“ beim Erzeugen gestreut. Auf diese Weise erhält jedes Slice nur die Gradienten für die Gewichtungen, für die es verantwortlich ist.

Ähnlich wie bei der Datenparallelität erfordert FSDP eine lineare Skalierung der globalen Batchgröße mit der Anzahl der Sektoren. FSDP verringert die Speicherauslastung, wenn Sie die Anzahl der Slices erhöhen. Dies liegt daran, dass die Anzahl der Gewichtungen und der Optimierungsstatus pro Slice abnimmt, allerdings auf den Preis von erhöhtem Netzwerkverkehr und der größeren Wahrscheinlichkeit einer Blockierung aufgrund eines verzögerten Kollektivs.

In der Praxis ist FSDP über Slices hinweg am besten, wenn Sie den Batch pro Slice erhöhen und mehr Aktivierungen speichern, um die Rematerialisierung während des Rückwärtsdurchlaufs zu minimieren, oder die Anzahl der Parameter in Ihrem neuronalen Netzwerk erhöhen.

Die Vorgänge zum Erfassen und Reduzieren der Gesamtdaten in FSDP funktionieren ähnlich wie die Vorgänge in DP. Sie können also wie im vorherigen Abschnitt beschrieben feststellen, ob Ihre FSDP-Arbeitslast durch die DCN-Leistung begrenzt ist.

Wann sollte die Pipelineparallelität verwendet werden?

Die Pipelineparallelität wird relevant, wenn eine hohe Leistung mit anderen Parallelitätsstrategien erzielt wird, die eine globale Batchgröße erfordern, die größer als Ihre bevorzugte maximale Batchgröße ist. Durch die Pipelineparallelität können die Segmente in einer Pipeline einen Batch „teilen“. Die Pipelineparallelität hat jedoch zwei wesentliche Nachteile:

  1. Es tritt die „Pipeline-Bubble“ auf, in der die Chips inaktiv sind, weil sie auf Daten warten.
  2. Es erfordert Mikro-Batching, das die effektive Batchgröße, die arithmetische Intensität und letztendlich die FLOP-Auslastung verringert.

Pipelineparallelität sollte nur verwendet werden, wenn die anderen Parallelitätsstrategien eine zu große globale Batchgröße erfordern. Bevor Sie die Pipelineparallelität testen, sollten Sie empirisch testen, ob die Konvergenz pro Stichprobe bei der Batchgröße verlangsamt wird, die für ein Hochleistungs-FSDP erforderlich ist. FSDP erreicht tendenziell eine höhere FLOP-Modellauslastung, aber wenn sich die Konvergenz pro Stichprobe mit zunehmender Batchgröße verlangsamt, ist die Pipelineparallelität möglicherweise immer noch die bessere Wahl. Die meisten Arbeitslasten können ausreichend große Batchgrößen tolerieren, um nicht von der Pipelineparallelität zu profitieren. Die Arbeitslast kann jedoch abweichen.

Wenn Pipelineparallelität erforderlich ist, empfehlen wir, sie mit Datenparallelität oder FSDP zu kombinieren. Auf diese Weise können Sie die Pipelinetiefe minimieren und gleichzeitig die Batchgröße pro Pipeline erhöhen, bis die DCN-Latenz weniger zu einem Faktor für den Durchsatz wird. Konkret sollten Sie bei N Slices Pipelines mit der Tiefe 2 und N/2-Replikate der Datenparallelität, Pipelines mit der Datenparallelität von Daten und N/4-Replikaten usw. in Betracht ziehen, bis der Batch pro Pipeline so groß wird, dass die DCN-Sammlungen hinter der Arithmetik hinten verborgen werden können. Dadurch wird die durch die Pipelineparallelität einhergehende Verlangsamung minimiert, während Sie gleichzeitig über das globale Batchgrößenlimit hinaus skalieren können.

Best Practices für die Verwendung in mehreren Segmenten

Laden der Daten

Während des Trainings laden wir wiederholt Batches aus einem Dataset in das Modell. Ein effizientes, asynchrones Datenladeprogramm, das den Batch auf mehrere Hosts fragmentiert, ist wichtig, um zu vermeiden, dass die TPUs überlastet werden. Im aktuellen Datenladeprogramm in MaxText wird jeder Hostlast die gleiche Teilmenge der Beispiele zugewiesen. Diese Lösung ist für Text geeignet, erfordert aber einen Reshard im Modell. Darüber hinaus bietet MaxText noch keine deterministischen Snapshots, die es dem Daten-iterator ermöglichen würden, dieselben Daten vor und nach dem vorzeitigen Beenden zu laden.

Prüfpunktausführung

Die Prüfpunktbibliothek von Orbax bietet Primitive, mit denen JAX PyTrees im lokalen Speicher oder in Google Cloud Storage geprüft werden können. Wir bieten eine Referenzintegration mit synchronem Prüfpunkt in MaxText in checkpointing.py.

Unterstützte Konfigurationen

Formen

Alle Segmente müssen die gleiche Form haben, z. B. dieselbe AcceleratorType. Heterogene Segmentformen werden nicht unterstützt.

Orchestrierung

Die Orchestrierung wird mit der GKE unterstützt. Weitere Informationen finden Sie unter TPUs in GKE.

Frameworks

Multislice unterstützt nur JAX- und PyTorch-Arbeitslasten.

Parallelität

Wir empfehlen Nutzern, Multislices mit Datenparallelität zu testen. Weitere Informationen zum Implementieren der Pipelineparallelität mit Multislice erhalten Sie von Ihrem Google Cloud-Kundenbetreuer.

Support und Feedback

Wir freuen uns über jedes Feedback! Wenn Sie Feedback geben oder Support anfordern möchten, kontaktieren Sie uns über das Cloud TPU-Support- oder Feedbackformular.