Cloud TPU-Multislice – Übersicht

Cloud TPU Multislice ist eine Full-Stack-Technologie zur Leistungsskalierung, die es einem Trainingsjob ermöglicht, mehrere TPU-Segmente in einem einzelnen Pod oder in Segmenten in mehreren Pods mit einfacher Datenparallelität zu verwenden. Mit TPU v4-Chips bedeutet dies, dass Trainingsjobs mehr als 4.096 Chips in einem einzigen Durchlauf verwenden können. Bei Trainingsjobs, die weniger als 4.096 Chips erfordern, kann ein einzelnes Slice die beste Leistung bieten. Mehrere kleinere Slices sind jedoch leichter verfügbar, was eine schnellere Startzeit ermöglicht, wenn Multislice mit kleineren Slices verwendet wird.

Mehrere Segmente skalieren die Leistung linear

Bei der Bereitstellung in Multislice-Konfigurationen kommunizieren TPU-Chips in jedem Segment über Inter-Chip-Interconnect (ICI). TPU-Chips in verschiedenen Segmenten kommunizieren durch Übertragen von Daten an CPUs (Hosts), die wiederum die Daten über das Rechenzentrumsnetzwerk (DCN) übertragen.

Datenfluss mit mehreren Segmenten

Entwickler müssen keinen Code schreiben, um die DCN-Kommunikation zwischen den Segmenten zu implementieren. Der XLA-Computr generiert diesen Code für Sie und überschneidet die Kommunikation mit der Berechnung, um maximale Leistung zu erzielen.

Konzepte

Beschleunigertyp
Die Form jedes TPU-Slice, das ein Multislice enthält. Jeder Slice in einer Anfrage mit mehreren Segmenten hat denselben Beschleunigertyp. Ein Beschleunigertyp besteht aus einem TPU-Typ (v4 oder v5e) gefolgt von der Anzahl der TensorCores. Beispielsweise gibt v4-128 eine TPU v4 mit 128 TensorCores an.
Automatische Reparatur
Wenn bei einem Slice ein Wartungsereignis, ein vorzeitiges Beenden oder ein Hardwarefehler auftritt, erstellt Cloud TPU ein neues Slice. Im seltenen Fall, dass nicht genügend Ressourcen zum Erstellen eines neuen Slice vorhanden sind, wird die Erstellung erst abgeschlossen, wenn Hardware verfügbar ist. Nachdem das neue Slice erstellt wurde, werden alle anderen Slices in der Multislice-Umgebung neu gestartet, damit das Training fortgesetzt werden kann. Mit einem korrekt konfigurierten Startskript kann das Trainingsskript ohne Eingriff des Nutzers automatisch neu gestartet werden. Dabei wird der letzte Prüfpunkt geladen und fortgesetzt.
Dataset
Die Daten, die von einem Modell für Training oder Inferenz verwendet werden.
Rechenzentrumsnetzwerke (DCN)
Ein Netzwerk mit höherer Latenz und geringerem Durchsatz (im Vergleich zu ICI), das TPU-Slices in einer Multislice-Konfiguration verbindet.
Gruppenplanung
Wenn alle TPU-Slices gemeinsam bereitgestellt werden, wird garantiert, dass entweder alle oder keine Slices erfolgreich bereitgestellt werden.
Moderator:in
Ein Host ist ein physischer Computer, auf dem VMs ausgeführt werden. Auf einem Host können maximal vier VMs gleichzeitig ausgeführt werden. Jede VM hat eine dedizierte TPU.
Inferenz
Vortrainiertes ML-Modell auf einen Host laden und Vorhersagen für Daten treffen.
Interchip Interconnect (ICI)
Interne Links mit hoher Geschwindigkeit und niedriger Latenz, die TPUs in einem TPU Pod verbinden.
Multislice
Zwei oder mehr TPU-Chip-Slices, die über DCN kommunizieren können.
Knoten
Im Multislice-Kontext bezieht sich ein Knoten auf ein einzelnes TPU-Slice. Jedem TPU-Slice in einem Multislice wird eine Knoten-ID zugewiesen.
Pod
Eine Sammlung von TPU-Chips, die über dedizierte ICI-Netzwerkschnittstellen verbunden sind. Mit einem Pod können Sie die Verarbeitungslast auf mehrere TPUs verteilen.
Ressource in der Warteschlange (QR)
Eine Darstellung von TPU-Ressourcen, mit denen eine Anfrage für eine Einzel- oder Multislice-TPU-Umgebung in die Warteschlange gestellt und verwaltet wird.
Startskript
Ein standardmäßiges Compute Engine-Startskript, das bei jedem Start oder Neustart einer VM ausgeführt wird. Für Multislice wird er in der Anfrage zur QR-Erstellung angegeben. Weitere Informationen zu Cloud TPU-Startskripts finden Sie unter TPU-Ressourcen verwalten.
TPU-Slice
Ein logischer Unterabschnitt eines TPU-Pod, der aus TPU-Chips besteht. Alle Chips in einem Slice kommunizieren über das ICI-Netzwerk miteinander.
TPU-VM
Eine virtuelle Maschine, auf der Linux ausgeführt wird und die Zugriff auf die zugrunde liegenden TPUs hat. Bei v4-TPUs hat jede TPU-VM direkten Zugriff auf vier Chips. Manchmal wird eine TPU-VM als Worker bezeichnet.
Google Tensor
Eine Datenstruktur, die zur Darstellung mehrdimensionaler Daten in einem Modell für maschinelles Lernen verwendet wird.
Tensor Processing Unit (TPU)
Der intern entwickelte Chip für die ML-Beschleunigung von Google. Sie wurden entwickelt, um schnelles und energieeffizientes Computing für wichtige ML-Aufgaben wie die Matrixmultiplikation zu ermöglichen.
Arten der Cloud TPU-Kapazität

TPUs können aus verschiedenen Kapazitätstypen erstellt werden (siehe Nutzungsoptionen unter TPU-Preise).

  • Reservierung: Ziel auf reserviertes Kontingent. Damit Sie reserviertes Kontingent verwenden können, müssen Sie eine Reservierungsvereinbarung mit Google haben. Verwenden Sie beim Erstellen Ihrer Ressourcen das Flag --reserved.
  • Spot: Zielt Kontingent auf Abruf mithilfe von Spot-VMs an. Ihre Ressourcen können vorzeitig beendet werden, um Platz für Anfragen für einen Job mit höherer Priorität zu schaffen. Verwenden Sie beim Erstellen Ihrer Ressourcen das Flag --spot.
  • On-Demand: Bezieht sich auf ein On-Demand-Kontingent, das keine Reservierung erfordert und nicht vorzeitig beendet wird. Die TPU-Anfrage wird in eine von Cloud TPU angebotene On-Demand-Kontingentwarteschlange eingereiht. Die Verfügbarkeit von Ressourcen ist nicht garantiert. Standardmäßig ausgewählt, keine Flags erforderlich.

Jetzt starten

Wenn Sie noch keine TPUs verwendet haben, installieren Sie zuerst die Google Cloud CLI und richten Sie Ihre Cloud TPU-Umgebung ein. Für die Verwendung von Multislice müssen Ihre TPU-Ressourcen als Ressourcen in der Warteschlange verwaltet werden.

Wenn Sie bereits TPU v4-Nutzer sind und eine Reservierung haben, müssen Sie Ihre Reservierung möglicherweise zu einem neuen Reservierungssystem migrieren. Weitere Informationen erhalten Sie von Ihrem Google Cloud-Kundenbetreuer.

Einleitendes Beispiel

In dieser Anleitung wird Code aus dem MaxText GitHub-Repository verwendet. MaxText ist ein leistungsstarkes, beliebig skalierbares Open-Source- und gut getestetes einfaches LLM, das in Python und Jax geschrieben wurde. MaxText wurde für ein effizientes Training auf Cloud TPU entwickelt.

Der Code in shardings.py soll Ihnen den Einstieg in das Experimentieren mit verschiedenen Parallelitätsoptionen erleichtern. Zum Beispiel Datenparallelität, vollständig fragmentierte Datenparallelität (FSDP) und Tensor-Parallelität. Der Code wird von Einzelsegmenten bis zu Umgebungen mit mehreren Segmenten skaliert.

ICI-Parallelität

ICI bezieht sich auf die Hochgeschwindigkeits-Verbindung, die die TPUs in einem einzelnen Slice verbindet. Die ICI-Fragmentierung entspricht der Fragmentierung innerhalb eines Slice. shardings.py bietet drei ICI-Parallelitätsparameter:

  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

Die für diese Parameter angegebenen Werte bestimmen die Anzahl der Shards für jede Parallelitätsmethode.

Diese Eingaben müssen so beschränkt werden, dass ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism der Anzahl der Chips im Segment entspricht.

Die folgende Tabelle zeigt Beispielnutzereingaben für die ICI-Parallelität für die vier in v4-8 verfügbaren Chips:

ici_data_parallelism ici_fsdp_parallelism ici_tensor_parallelism
4-Wege-FSDP 1 4 1
4-Wege-Tensor-Parallelität 1 1 4
2-Wege-FSDP + 2-Wege-Tensor-Parallelität 1 2 2

Beachten Sie, dass ici_data_parallelism in den meisten Fällen auf 1 belassen sollte, da das ICI-Netzwerk schnell genug ist, um FSDP der Datenparallelität fast immer vorzuziehen.

In diesem Beispiel wird davon ausgegangen, dass Sie mit dem Ausführen von Code auf einem einzelnen TPU-Slice vertraut sind, z. B. in Berechnung auf einer Cloud TPU-VM mit JAX ausführen. Dieses Beispiel zeigt, wie shardings.py für ein einzelnes Segment ausgeführt wird.

  1. Richten Sie die Umgebung ein:

    $ gcloud auth login
    $ gcloud config set project your-project-id
    $ gcloud config set compute/zone your-zone
    
  2. Erstellen Sie SSH-Schlüssel für gcloud. Wir empfehlen, ein leeres Passwort zu lassen (drücken Sie zweimal die Eingabetaste, nachdem Sie den folgenden Befehl ausgeführt haben). Wenn Sie gefragt werden, dass die Datei google_compute_engine bereits vorhanden ist, ersetzen Sie die vorhandene Version.

    $ ssh-keygen -f ~/.ssh/google_compute_engine
    
  3. Stellen Sie die TPUs mit dem folgenden Befehl bereit:

    $ gcloud alpha compute tpus queued-resources \
    create your-qr-id \
    --accelerator-type your-accelerator-type \
    --runtime-version tpu-ubuntu2204-base \
    --node-id qr-id \
    [--reserved |--spot]
    

    Beschreibung der Befehls-Flags

    your-qr-id
    Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
    runtime-version
    Die [Cloud TPU-Softwareversion](/tpu/docs/supported-tpu-configurations#tpu_software_versions).
    node-id
    Die ID der TPU-Ressourcen, die als Antwort auf die QR-Anfrage erstellt wird.
    reserved
    Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
    best-effort
    Best-Effort-Kontingent beim Erstellen der Slices verwenden [Standardeinstellung].

    Die Google Cloud CLI unterstützt nicht alle Optionen zum Erstellen von QR-Codes, z. B. Tags. Weitere Informationen finden Sie unter Kurzantworten erstellen.

  4. Warten Sie, bis sich der QR-Code im Status ACTIVE befindet. Das bedeutet, dass sich die Worker-Knoten im Status READY befinden. Nach dem Start der QR-Bereitstellung kann es je nach Größe des QR-Codes ein bis fünf Minuten dauern, bis die Bereitstellung abgeschlossen ist. Sie können den Status einer QR-Anfrage mit dem folgenden Befehl prüfen:

    $ gcloud compute tpus queued-resources \
      list --filter=your-qr-id
    
  5. Ein v4-8-Slice hat eine einzelne TPU-VM. Stellen Sie über SSH eine Verbindung zur TPU-VM her:

    $ gcloud compute tpus tpu-vm ssh your-qr-id
    
  6. Klonen Sie MaxText (einschließlich shardings.py) auf Ihre TPU-VM.

  7. Führen Sie im MaxText-Repository-Verzeichnis das Setup-Skript aus, um JAX und andere Abhängigkeiten auf Ihrem TPU-Slice zu installieren. Das Ausführen des Installationsskripts dauert einige Minuten.

    $ bash setup.sh
    
  8. Führen Sie den folgenden Befehl aus, um shardings.py auf Ihrem TPU-Slice auszuführen.

    $ python3 pedagogical_examples/shardings.py \
      --ici_fsdp_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    

    Sie können die Ergebnisse in den Logs sehen. Ihre TPUs sollten etwa 260 TFLOP pro Sekunde oder eine beeindruckende FLOP-Auslastung von mehr als 90%erreichen. In diesem Fall haben wir ungefähr den maximalen Batch ausgewählt, der in den hohen Bandbreitenspeicher (HBM) der TPU passt.

  9. Sie können sich statt ICI auch mit anderen Fragmentierungsstrategien vertraut machen. Sie können beispielsweise die folgende Kombination ausprobieren:

    $ python3 pedagogical_examples/shardings.py \
      --ici_tensor_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    
  10. Löschen Sie den QR-Code und das TPU-Slice, wenn Sie fertig sind. Sie sollten diese Bereinigungsschritte in der Umgebung ausführen, in der Sie das Slice eingerichtet haben. Führen Sie dazu zuerst exit aus, um die SSH-Sitzung zu beenden. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

    $ gcloud compute tpus queued-resources
      delete your-qr-id --force (--async)
    

Multi-Slice-Fragmentierung mit DCN-Parallelität

Das Skript shardings.py verwendet drei Parameter, die die DCN-Parallelität festlegen und der Anzahl der Shards jedes Typs von Datenparallelität entsprechen:

  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism

Die Werte dieser Parameter müssen so begrenzt werden, dass dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism der Anzahl der Segmente entspricht.

Verwenden Sie als Beispiel für zwei Segmente --dcn_data_parallelism = 2.

dcn_data_parallelism dcn_fsdp_parallelism dcn_tensor_parallelism Anzahl der Segmente
Zweiwege-Datenparallelität 2 1 1 2

dcn_tensor_parallelism sollte immer auf 1 gesetzt sein, da das DCN nicht für eine solche Fragmentierung geeignet ist. Bei typischen LLM-Arbeitslasten auf v4-Chips sollte auch dcn_fsdp_parallelism auf 1 gesetzt werden. Daher sollte dcn_data_parallelism auf die Anzahl der Segmente festgelegt werden. Dies ist jedoch anwendungsabhängig.

Wenn Sie die Anzahl der Segmente erhöhen (vorausgesetzt, Sie halten die Segmentgröße und den Batch pro Segment konstant), erhöhen Sie die Datenparallelität.

shardings.py in einer Umgebung mit mehreren Segmenten ausführen

Sie können shardings.py in einer Multislice-Umgebung mit multihost_runner.py oder durch Ausführen von shardings.py auf jeder TPU-VM ausführen. Hier verwenden wir multihost_runner.py. Die folgenden Schritte ähneln denen unter Getting Started: Quick Experiments on Multiple Slices aus dem MaxText-Repository, mit der Ausnahme, dass wir hier shardings.py anstelle des komplexeren LLM in train.py ausführen.

Das multihost_runner.py-Tool ist für schnelle Tests optimiert, bei denen dieselben TPUs wiederholt wiederverwendet werden. Da das Skript multihost_runner.py von langlebigen SSH-Verbindungen abhängt, wird es nicht für Jobs mit langer Ausführungszeit empfohlen. Wenn Sie einen längeren Job ausführen möchten (z. B. Stunden oder Tage), empfehlen wir die Verwendung von multihost_job.py.

In dieser Anleitung verwenden wir den Begriff runner, um die Maschine anzugeben, auf der Sie das Skript multihost_runner.py ausführen. Der Begriff Worker gibt die TPU-VMs an, aus denen Ihre Slices bestehen. Sie können multihost_runner.py auf einem lokalen Computer oder einer beliebigen Compute Engine-VM ausführen, die sich im selben Projekt wie Ihre Slices befindet. Die Ausführung von multihost_runner.py auf einem Worker wird nicht unterstützt.

multihost_runner.py stellt automatisch über SSH eine Verbindung zu TPU-Workern her.

In diesem Beispiel wird shardings.py über zwei v4-16-Segmente, insgesamt vier VMs und 16 TPU-Chips, ausgeführt. Sie können das Beispiel so ändern, dass es auf mehr TPUs ausgeführt wird.

Umgebung einrichten

  1. Klonen Sie MaxText auf Ihrem Läufercomputer.

  2. Wechseln Sie zum Repository-Verzeichnis.

  3. Erstellen Sie SSH-Schlüssel für gcloud. Wir empfehlen, ein leeres Passwort zu lassen (drücken Sie zweimal die Eingabetaste, nachdem Sie den folgenden Befehl ausgeführt haben). Wenn Sie gefragt werden, dass die Datei google_compute_engine bereits vorhanden ist, wählen Sie aus, dass die vorhandene Version nicht beibehalten werden soll.

      $ ssh-keygen -f ~/.ssh/google_compute_engine
      

  4. Fügen Sie eine Umgebungsvariable hinzu, um die Anzahl der TPU-Slices auf 2 festzulegen.

      $ export SLICE_COUNT=2
      

  5. Erstellen Sie mit queued-resources create eine Multislice-Umgebung.

    Der folgende Befehl zeigt, wie Sie eine v4-Multislice-TPU erstellen. Wenn Sie v5e verwenden möchten, geben Sie für v5e einen accelerator-type (z. B. v5litepod-16) und einen runtime-version für v5e (v2-alpha-tpuv5-lite) an.

      $ gcloud alpha compute tpus queued-resources 
    create your-qr-id
    --accelerator-type=your-accelerator-type
    --runtime-version=tpu-vm-runtime-version
    --node-count=node-count
    --node-prefix=your-qr-id
    [--reserved|--spot]

    Beschreibung der Befehls-Flags

    your-qr-id
    Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
    runtime-version
    Die Cloud TPU-Softwareversion.
    node-count
    Die Anzahl der zu erstellenden Segmente.
    node-prefix
    Das Präfix, mit dem Namen für die einzelnen Segmente generiert werden. An das Präfix wird für jedes Segment eine Zahl angehängt. Wenn Sie beispielsweise node-prefix auf mySlice setzen, heißen die Segmente mySlice-0, mySlice-1 usw.
    reserved
    Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
    best-effort
    Best-Effort-Kontingent beim Erstellen der Slices verwenden [Standardeinstellung].

  6. Wenn die QR-Bereitstellung gestartet wird, kann dies je nach Größe des QR-Codes bis zu fünf Minuten dauern. Warten Sie, bis sich die Ressource in der Warteschlange (QR) im Status ACTIVE befindet. Sie können den Status einer QR-Anfrage mit dem folgenden Befehl prüfen:

    $ gcloud compute tpus queued-resources list \
    --filter=your-qr-id
    

    Die Ausgabe sollte in etwa so aussehen:

    NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
    ...
    que-res-id  us-central2-b  4           v4-16             ACTIVE
    ...
    

    Wenden Sie sich an Ihren Google Cloud-Kundenbetreuer, wenn der QR-Status länger als 15 Minuten den Status WAITING_FOR_RESOURCES oder PROVISIONING hat.

  7. Installieren Sie die Abhängigkeiten:

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="bash setup.sh"
    
  8. Führen Sie shardings.py mit multihost_runner.py auf jedem Worker aus.

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="python3 pedagogical_examples/shardings.py \
      --dcn_data_parallelism $SLICE_COUNT \
      --ici_fsdp_parallelism 8 \
      --batch_size 131072 \
      --embedding_dimension 2048"
    

    In den Logdateien sehen Sie eine Leistung von etwa 230 TFLOPs pro Sekunde.

  9. Bereinigen Sie die TPUs und den QR-Code, wenn Sie fertig sind. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

Arbeitslast auf Multislice skalieren

Bevor Sie Ihr Modell in einer Umgebung mit mehreren Segmenten ausführen, nehmen Sie die folgenden Änderungen am Code vor:

Dies sollten die einzigen erforderlichen Codeänderungen sein, die beim Wechsel zur Multislice-Funktion erforderlich sind. Um eine hohe Leistung zu erzielen, muss das DCN auf parallelen, vollständig fragmentierten Datenachsen oder parallelen Pipelineachsen zugeordnet werden. Leistungsaspekte und Fragmentierungsstrategien werden unter Fragmentierung mit Multislice für maximale Leistung ausführlicher erörtert.

Um zu prüfen, ob Ihr Code auf alle Geräte zugreifen kann, können Sie bestätigen, dass len(jax.devices()) der Anzahl der Chips in Ihrer Multislice-Umgebung entspricht. Wenn Sie beispielsweise vier Segmente von v4-16 verwenden, haben Sie acht Chips pro Segment × 4 Slices. Daher sollte len(jax.devices()) 32 zurückgeben.

Segmentgrößen für Multislice-Umgebungen auswählen

Fügen Sie neue Slices hinzu, die dieselbe Größe wie das vorhandene Slice haben, um eine lineare Beschleunigung zu erhalten. Wenn Sie beispielsweise ein v4-512-Slice verwenden, erreicht die Multislice-Funktion etwa die doppelte Leistung, wenn Sie ein zweites v4-512-Slice hinzufügen und die globale Batchgröße verdoppeln. Weitere Informationen finden Sie unter Fragmentierung mit Multislice für maximale Leistung.

Job in mehreren Segmenten ausführen

Es gibt drei verschiedene Ansätze, um Ihre benutzerdefinierte Arbeitslast in einer Multislice-Umgebung auszuführen:

  1. Mit dem Skript multihost_runner.py für die Testausführung
  2. Mit dem Produktions-Runner-Skript multihost_job.py
  3. Manueller Ansatz

Script für die Testausführung

Das Skript multihost_runner.py verteilt den Code an eine vorhandene Multislice-Umgebung, führt den Befehl auf jedem Host aus, kopiert Ihre Logs zurück und verfolgt den Fehlerstatus jedes Befehls. Das Skript multihost_runner.py ist in der README-Datei für MaxText dokumentiert.

Da multihost_runner.py persistente SSH-Verbindungen pflegt, eignet es sich nur für mittelgroße Tests mit relativ kurzer Dauer. Sie können die Schritte in der multihost_runner.py-Anleitung an Ihre Arbeitslast- und Hardwarekonfiguration anpassen.

Skript für die Produktionsausführung

Für Produktionsjobs, die Ausfallsicherheit gegen Hardwarefehler und andere vorzeitige Beendigung benötigen, empfiehlt es sich, die API direkt in die Create Queued Resource API einzubinden. Als Arbeitsbeispiel stellen wir multihost_job.py bereit, das den Created Queued Resource API-Aufruf mit dem entsprechenden Startskript auslöst, um Ihr Training auszuführen und bei vorzeitigem Beenden fortzufahren. Das Skript multihost_job.py ist in der README-Datei für MaxText dokumentiert.

Da multihost_job.py Ressourcen für jede Ausführung bereitstellen muss, bietet sie nicht einen so schnellen Iterationszyklus wie multihost_runner.py.

Manueller Ansatz

Wir empfehlen, multihost_runner.py oder multihost_job.py zu verwenden oder anzupassen, um Ihre benutzerdefinierte Arbeitslast in Ihrer Multislice-Konfiguration auszuführen. Wenn Sie Ihre Umgebung jedoch lieber direkt mit QR-Befehlen bereitstellen und verwalten möchten, finden Sie weitere Informationen unter Multislice-Umgebung verwalten.

Multislice-Umgebung verwalten

In den folgenden Abschnitten erfahren Sie, wie Sie QR-Codes manuell bereitstellen und verwalten können, ohne die im MaxText-Repository bereitgestellten Tools zu verwenden.

Kurzantworten erstellen

Legen Sie die folgenden Umgebungsvariablen fest, bevor Sie Kapazität bereitstellen:

  $ export your-qr-id=your-queued-resource-id
  $ export PROJECT=your-project-name
  $ export ZONE=us-central2-b
  $ export NETWORK_NAME=your-network-name
  $ export SUBNETWORK_NAME=your-subnetwork-name
  $ export RUNTIME_VERSION=tpu-ubuntu2204-base
  $ export ACCELERATOR_TYPE=v4-16
  $ export SLICE_COUNT=4
  $ export STARTUP_SCRIPT="#!/bin/bash\n ..."
  $ gcloud config set project project-name
  $ gcloud config set compute/zone zone
Eingabe Beschreibung
your-qr-id Die vom Nutzer zugewiesene ID des QR-Codes.
PROJEKT Name des Google Cloud-Projekts
ZONE us-central2-b
NETWORK_NAME Name der VPC-Netzwerke.
SUBNETWORK_NAME Name des Subnetzes in VPC-Netzwerken
RUNTIME_VERSION tpu-Ubuntu2204-base
ACCELERATOR_TYPE v4-16
EXAMPLE_TAG_1, EXAMPLE_TAG_2... Tags zum Identifizieren gültiger Quellen oder Ziele für Netzwerkfirewalls
SLICE_COUNT Anzahl der Segmente. Begrenzt auf maximal 256 Segmente.
STARTUP_SCRIPT Wenn Sie der Anfrage zur Erstellung ein Startskript hinzufügen, kann jedes Mal ausgeführt werden, wenn ein TPU-Slice bereitgestellt oder neu gestartet wird und wenn das TPU-Slice repariert oder zurückgesetzt wird.

QR-Anfrage mit gcloud erstellen

$ gcloud alpha compute tpus queued-resources \
  create ${your-qr-id} \
  --project your-project-id \
  --zone your-zone \
  --node-count ${SLICE_COUNT} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --network ${NETWORK_NAME} \
  --subnetwork ${SUBNETWORK_NAME} \
  --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}'
  [--reserved|--spot]
  

Beschreibung der Befehls-Flags

your-qr-id
Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
project
Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
zone
Die Google Cloud-Zone, in der der QR-Code erstellt werden soll.
node-count
Die Anzahl der zu erstellenden Segmente.
accelerator-type
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
runtime-version
Die Cloud TPU-Softwareversion.
network
Der Name eines VPC-Netzwerk, an das die TPU-Ressource angehängt werden soll.
subnetwork
Der Name eines VPC-Subnetzwerks, an das die TPU-Ressource angehängt werden soll.
reserved
Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
spot
Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.

Prüfen Sie, ob Sie das entsprechende Kontingent haben, bevor Sie --reserved, --spot oder das standardmäßige On-Demand-Kontingent auswählen. Informationen zu Kontingenttypen finden Sie in der Kontingentrichtlinie.

QR-Anfrage mit curl erstellen

Erstellen Sie eine Datei mit dem Namen queued-resource-req.json und kopieren Sie den folgenden JSON-Code hinein.

{
  "guaranteed": { "reserved": true },
  "tpu": {
    "node_spec": [
    {
      "parent": "projects/your-project-number/locations/your-zone",
        "node": {
          "accelerator_type": "accelerator-type",
          "runtime_version": "tpu-vm-runtime-version",
          "network_config": {
            "network": "your-network-name",
            "subnetwork": "your-subnetwork-name",
            "enable_external_ips": true
          },
          "tags" : ["example-tag-1"]
          "metadata": {
            "startup-script": "your-startup-script"
          }
      },
      "multi_node_params": {
        "node_count": slice-count,
        "node_id_prefix": "your-queued-resource-id"
      }
    }
    ]
  }
}
  • your-project-number – Ihre Google Cloud-Projektnummer
  • your-zone – Der Bereich, in dem Sie den QR-Code erstellen möchten
  • accelerator-type – Version und Größe eines einzelnen Slice
  • tpu-vm-runtime-version: Die TPU-VM-Laufzeitversionen
  • your-network-name – Optional ein Netzwerk, an das der QR-Code angehängt wird
  • your-subnetwork-name: Optional ein Subnetzwerk, an das der QR-Code angehängt wird
  • example-tag-1 (optional) ein beliebiger Tag-String
  • your-startup-script – ein Startskript, das ausgeführt wird, wenn der QR-Code zugewiesen wird
  • slice-count – Die Anzahl der TPU-Slices in Ihrer Multislice-Umgebung
  • your-qr-id – Die vom Nutzer angegebene ID für den QR-Code

Weitere Informationen zu allen verfügbaren Optionen finden Sie in der Dokumentation zur REST Queued Resource API.

Ersetzen Sie Folgendes, um die Spot-Kapazität zu verwenden:

"guaranteed": { "reserved": true } mit "spot": {}

Entfernen Sie die Zeile, um die standardmäßige On-Demand-Kapazität zu verwenden.

Senden Sie die Anfrage zum Erstellen des QR-Codes mit der JSON-Nutzlast:

  $ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" -d @queuedresourcereq.json https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
  • your-project-id – Ihre Google Cloud-Projekt-ID
  • your-zone – Der Bereich, in dem Sie den QR-Code erstellen möchten
  • your-qr-id – Die vom Nutzer angegebene ID für den QR-Code

Die Antwort sollte in etwa so aussehen:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
  },
  "done": false
}

Verwenden Sie den GUID-Wert am Ende des Stringwerts für das Attribut name, um Informationen zur QR-Anfrage zu erhalten.

Status eines QR-Codes abrufen

Verwenden Sie den folgenden Befehl, um den Status der QR-Anfrage abzurufen:

  $ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
  • your-project-id – Ihre Google Cloud-Projekt-ID
  • your-zone – Der Bereich, in dem der QR-Code erstellt werden soll.
  • your-qr-guid: Die GUID, die in der Ausgabe der Anfrage zur Erstellung des QR-Codes auf name folgt.

Die Antwort auf diesen Befehl enthält den Status des Vorgangs:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qa-guid>,
  "metadata": {...},
  "done": true,
  "response": {
    "@type": "type.googleapis.com/google.cloud.tpu.v2.QueuedResource",
    ...
    "state": {
      "state": "WAITING_FOR_RESOURCES"
    }
  }
}

Wenn der QR-Code erfolgreich mit ("done = true") erstellt wurde, ist der Status im Feld response entweder WAITING_FOR_RESOURCES oder FAILED. Wenn der QR-Code den Status WAITING_FOR_RESOURCES hat, wurde er in die Warteschlange gestellt und beginnt mit der Bereitstellung, sobald genügend Ressourcen vorhanden sind. Wenn der QR-Code den Status FAILED hat, wird die Fehlerursache in der Ausgabe angegeben. Weitere Informationen zu anderen möglichen Status finden Sie im Nutzerhandbuch für Ressourcen in der Warteschlange.

Nachdem der Vorgang abgeschlossen ist, kannst du die Phasen des QR-Codes mithilfe von QR-Beschreibungen überwachen.

In einem seltenen Fall kann es sein, dass der QR-Code den Status FAILED hat, während einige Segmente den Status ACTIVE haben. Löschen Sie in diesem Fall die erstellten Ressourcen und versuchen Sie es in einigen Minuten noch einmal. Sie können sich auch an das Cloud TPU-Team wenden, um das Problem zu beheben.

SSH-Verbindung herstellen und Abhängigkeiten installieren

Unter JAX-Code auf TPU-Pod-Slices ausführen wird beschrieben, wie Sie eine Verbindung zu Ihren TPU-VMs mithilfe von SSH in einem einzelnen Slice herstellen. Verwenden Sie den folgenden gcloud-Befehl, um über SSH eine Verbindung zu allen TPU-VMs in Ihrer Multislice-Umgebung herzustellen und Abhängigkeiten zu installieren:

  $ gcloud compute tpus queued-resources ssh ${your-qr-id} \
    --zone your-zone \
    --node=all \
    --worker=all \
    --command="command-to-run"
    --batch-size=4

Mit diesem gcloud-Befehl wird der angegebene Befehl über SSH an alle Worker und Knoten im QR-Code gesendet. Der Befehl wird in Vierergruppen zusammengefasst und gleichzeitig gesendet. Der nächste Batch von Befehlen wird gesendet, wenn der aktuelle Batch die Ausführung abgeschlossen hat. Wenn bei einem der Befehle ein Fehler auftritt, wird die Verarbeitung beendet und es werden keine weiteren Batches gesendet. Weitere Informationen finden Sie in der API-Referenz für Ressourcen in der Warteschlange. Wenn die Anzahl der verwendeten Slices das Threading-Limit Ihres lokalen Computers (auch Batching-Limit genannt) überschreitet, kommt es zu einem Deadlock. Angenommen, das Batchlimit auf Ihrem lokalen Rechner beträgt 64. Wenn Sie versuchen, ein Trainingsskript für mehr als 64 Slices auszuführen, z. B. 100 Slices, teilt der SSH-Befehl die Slices in Batches auf. Das Trainingsskript wird im ersten Batch mit 64 Slices ausgeführt. Es wird auf den Abschluss der Skripts gewartet, bevor es für den verbleibenden Batch von 36 Slices ausgeführt wird. Der erste Batch von 64 Slices kann jedoch erst abgeschlossen werden, wenn die restlichen 36 Slices mit der Ausführung des Skripts beginnen, was zu einem Deadlock führt.

Um dies zu verhindern, können Sie das Trainingsskript im Hintergrund auf jeder VM ausführen. Dazu hängen Sie ein kaufmännisches Und (&) an den Skriptbefehl an, den Sie mit dem Flag --command angeben. Wenn Sie dies tun, wird die Steuerung sofort an den SSH-Befehl zurückgegeben, nachdem das Trainingsskript für den ersten Batch von Segmenten gestartet wurde. Der SSH-Befehl kann dann mit der Ausführung des Trainingsskripts für den verbleibenden Batch von 36 Slices beginnen. Sie müssen Ihre stdout- und stderr-Streams über eine entsprechende Pipeline bereitstellen, wenn Sie die Befehle im Hintergrund ausführen. Um die Parallelität innerhalb desselben QR-Codes zu erhöhen, können Sie mit dem Parameter --node bestimmte Segmente auswählen.

Netzwerkeinrichtung

Führen Sie die folgenden Schritte aus, damit TPU-Slices miteinander kommunizieren können. Installieren Sie JAX auf jedem Segment. Weitere Informationen finden Sie unter JAX-Code auf TPU-Pod-Slices ausführen. Bestätigen Sie, dass len(jax.devices()) der Anzahl der Chips in Ihrer Multislice-Umgebung entspricht. Führen Sie dazu für jedes Segment folgenden Befehl aus:

  $ python3 -c 'import jax; print(jax.devices())'

Wenn Sie diesen Code auf vier Segmenten von v4–16 ausführen, gibt es acht Chips pro Slice und vier Segmente. Von jax.devices() sollten insgesamt 32 Chips (Geräte) zurückgegeben werden.

Kurzantworten auflisten

Sie können den Status Ihrer QRs mit dem Befehl queued-resources list abrufen:

$ gcloud compute tpus queued-resources list

NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
...
que-res-id  us-central2-b  4           v4-16             ACTIVE
...

Kurzantworten beschreiben

Um die detaillierte Konfiguration und den Status eines QR-Codes anzusehen, verwende die describe QR API. Sie können diese API mit gcloud oder curl aufrufen.

mit gcloud:

$ gcloud compute tpus queued-resources describe ${your-qr-id}
...state:
 state: ACTIVE
...

mit curl:

$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/${your-qr-id}
{
  "name": your-queued-res,
  "tpu": {
    "nodeSpec": [
      {
        ... // node 1
      },
      {
        ... // node 2
      },
      ...
    ]
  },
  ...
  "state": "ACTIVE"
}

state steht für den Status eines QR-Codes. Weitere Informationen zu den möglichen Status von QRs finden Sie unter Ressourcen in der Warteschlange.

Job in einer bereitgestellten Umgebung starten

Sie können Arbeitslasten manuell ausführen, indem Sie über SSH eine Verbindung zu allen Hosts in jedem Segment herstellen und den folgenden Befehl auf allen Hosts ausführen.

$ gcloud compute tpus tpu-vm ssh your-qr-id \
  --zone=your-zone \
  --worker=all \
  --node=all \
  --command="command-to-run"

Kurzantworten zurücksetzen

Mit der ResetQueuedResource API können alle VMs in einem ACTIVE-QR zurückgesetzt werden. Durch das Zurücksetzen der VMs wird das Löschen des Arbeitsspeichers der Maschine erzwungen und die VM auf ihren Ausgangszustand zurückgesetzt. Alle lokal gespeicherten Daten bleiben intakt und das Startskript wird nach dem Zurücksetzen aufgerufen. Die ResetQueuedResource API kann nützlich sein, wenn Sie alle TPUs neu starten möchten. Wenn beispielsweise das Training hängen bleibt und das Zurücksetzen aller VMs einfacher ist als das Debugging.

Das Zurücksetzen aller VMs erfolgt parallel. Ein ResetQueuedResource-Vorgang dauert ein bis zwei Minuten. Verwenden Sie den folgenden Befehl, um die API aufzurufen:

$ gcloud compute tpus queued-resources reset your-qr-id

QR-Codes werden gelöscht

Wenn Sie Ressourcen am Ende der Trainingssitzung freigeben möchten, löschen Sie die in der Warteschlange angegebene Ressource mit dem Flag --force. Das Löschen dauert zwei bis fünf Minuten und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

$ gcloud compute tpus queued-resources \
delete your-qr-id --force (--async)

Automatische Wiederherstellung nach Fehlern

Im Falle einer Unterbrechung bietet Multislice eine interventionsfreie Reparatur des betroffenen Slice und das anschließende Zurücksetzen aller Slice an. Das betroffene Slice wird durch ein neues Slice ersetzt und die verbleibenden ansonsten fehlerfreien Slices zurückgesetzt. Wenn keine Kapazität zum Zuweisen eines Ersatzsegments verfügbar ist, wird das Training beendet.

Damit das Training nach einer Unterbrechung automatisch fortgesetzt wird, müssen Sie ein Startskript angeben, das die zuletzt gespeicherten Prüfpunkte prüft und diese lädt. Das Startskript wird automatisch jedes Mal ausgeführt, wenn ein Segment neu zugewiesen oder eine VM zurückgesetzt wird. Ein Startskript geben Sie in der JSON-Nutzlast an, die Sie an die API zum Erstellen von QR-Anfragen senden.

Mit dem folgenden Startskript, das in QRs erstellen verwendet wird, können Sie Fehler automatisch wiederherstellen und das Training an Prüfpunkten fortsetzen, die während des MaxText-Trainings in einem Cloud Storage-Bucket gespeichert sind:

{
 "tpu": {
   "node_spec": [
     {
      ...
         "metadata": {
               "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
         }
     ...
     }
   ]
 }
}

Klonen Sie das MaxText-Repository, bevor Sie dies ausprobieren.

Profilerstellung und Fehlerbehebung

Die Profilerstellung ist in Einzel-Slice- und Multi-Slice-Umgebungen identisch. Weitere Informationen finden Sie unter Profilerstellung für JAX-Programme.

Optimierte Schulungen

Fragmentierung mit Multislice für maximale Leistung

Wenn Sie in Umgebungen mit mehreren Segmenten maximale Leistung erzielen möchten, müssen Sie überlegen, wie die Segmente in mehrere Segmente aufgeteilt werden sollen. Üblicherweise gibt es drei Auswahlmöglichkeiten (Datenparallelität, vollständig fragmentierte Datenparallelität und Pipeline-Parallelität). Wir raten davon ab, Aktivierungen über die Modelldimensionen hinweg zu fragmentieren (manchmal auch als Tensor-Parallelität bezeichnet), da dafür zu viel Bandbreite zwischen den Segmenten benötigt wird. Bei allen diesen Strategien können Sie dieselbe Fragmentierungsstrategie in einem Slice belassen, mit dem sich die Vergangenheit für Sie bewährt hat.

Wir empfehlen, mit reiner Datenparallelität zu beginnen. Die Verwendung der vollständig fragmentierten Datenparallelität ist nützlich, um Arbeitsspeichernutzung freizugeben. Der Nachteil ist, dass die Kommunikation zwischen den Segmenten das DCN-Netzwerk nutzt und Ihre Arbeitslast verlangsamt. Verwenden Sie die Pipeline-Parallelität nur bei Bedarf basierend auf der Batchgröße (wie unten analysiert).

Wann sollte Datenparallelität verwendet werden?

Die reine Datenparallelität funktioniert gut, wenn Sie eine gut ausgeführte Arbeitslast haben, aber ihre Leistung durch Skalierung über mehrere Slices hinweg verbessern möchten.

Um eine starke Skalierung über mehrere Segmente hinweg zu erzielen, muss die für die Durchführung der vollständigen Reduzierung über DCN erforderliche Zeit geringer sein als die für eine Rückwärtsdurchführung. DCN wird für die Kommunikation zwischen Segmenten verwendet und ist ein begrenzender Faktor für den Arbeitslastdurchsatz.

Jeder v4-TPU-Chip hat eine maximale Leistung von 275 × 1012 FLOPS pro Sekunde.

Es gibt vier Chips pro TPU-Host und jeder Host hat eine maximale Netzwerkbandbreite von 50 Gbit/s.

Das bedeutet, dass die arithmetische Intensität 4 × 275 × 1012 FLOPS ÷ 50 Gbit / s = 22.000 FLOPS / Bit beträgt.

Ihr Modell verwendet 32 bis 64 Bit DCN-Bandbreite für jeden Parameter und Schritt. Wenn Sie zwei Segmente verwenden, benötigt Ihr Modell 32 Bit DCN-Bandbreite. Wenn Sie mehr als zwei Slices verwenden, führt der Compiler einen vollständigen Shuffle All-Reduce-Vorgang aus und Sie nutzen bis zu 64 Bit DCN-Bandbreite für jeden Parameter pro Schritt. Wie viele FLOPS für jeden Parameter erforderlich sind, hängt von Ihrem Modell ab. Insbesondere bei Transformer-basierten Sprachmodellen beträgt die Anzahl der FLOPS, die für eine Vorwärts- und Rückwärtsfahrt erforderlich sind, etwa 6 * B * P, wobei Folgendes gilt:

  • B ist die Batchgröße in Tokens.
  • P ist die Anzahl der Parameter.

Die Anzahl der FLOPS pro Parameter ist 6 * B und die Anzahl der FLOPS pro Parameter beim Rückwärtsdurchlauf 4 * B.

Damit eine starke Skalierung über mehrere Segmente hinweg gewährleistet ist, muss die Betriebsintensität die arithmetische Intensität der TPU-Hardware überschreiten. Zum Berechnen der Betriebsintensität teilen Sie die Anzahl der FLOPS pro Parameter während der Rückwärtsdurchführung durch die Netzwerkbandbreite (in Bit) pro Parameter und Schritt: Operational Intensity = FLOPSbackwards_pass / DCN bandwidth

Daher sollten Sie bei einem Transformer-basierten Sprachmodell zwei Segmente verwenden: Operational intensity = 4 * B / 32

Wenn Sie mehr als zwei Segmente verwenden: Operational intensity = 4 * B/64

Daraus ergibt sich eine Mindest-Batchgröße von 176.000 bis 352.000 für Transformer-basierte Sprachmodelle. Da das DCN-Netzwerk Pakete kurzzeitig verwerfen kann, sollte ein beträchtlicher Fehlerbereich eingehalten werden. Stellen Sie die Datenparallelität nur bereit, wenn die Batchgröße pro Pod mindestens 350.000 (zwei Pods) bis 700.000 (viele Pods) beträgt.

Bei anderen Modellarchitekturen müssen Sie die Laufzeit Ihrer Rückwärtstermine pro Segment abschätzen (entweder durch die zeitliche Abfolge mit einem Profiler oder durch Zählen von FLOPS). Anschließend können Sie dies mit der erwarteten Laufzeit vergleichen, um die Reduzierung über das DCN zu reduzieren, und eine gute Einschätzung davon erhalten, ob Datenparallelität für Sie sinnvoll ist.

Wann sollte die vollständig fragmentierte Datenparallelität (FSDP) verwendet werden?

Die vollständig fragmentierte Datenparallelität (Fully Sharded Data Parallelism (FSDP)) kombiniert die Datenparallelität (Fragmentierung der Daten über Knoten hinweg) mit der Fragmentierung der Gewichtungen über Knoten hinweg. Für jede Operation in den Vorwärts- und Rückwärtsdurchläufen werden die Gewichtungen vollständig erfasst, sodass jedes Stück über die benötigten Gewichte verfügt. Anstatt die Gradienten mit „All-Reduce“ zu synchronisieren, werden die Gradienten beim Erstellen redundant verteilt. Auf diese Weise erhält jedes Segment nur die Gradienten für die Gewichtungen, für die es verantwortlich ist.

Ähnlich wie bei der Datenparallelität erfordert FSDP die lineare Skalierung der globalen Batchgröße anhand der Anzahl der Segmente. FSDP verringert die Speicherauslastung, wenn Sie die Anzahl der Segmente erhöhen. Dies liegt daran, dass die Anzahl der Gewichtungen und der Optimierungsstatus pro Segment abnimmt, aber dies zum Preis eines erhöhten Netzwerktraffics und einer größeren Wahrscheinlichkeit einer Blockierung aufgrund eines verzögerten Kollektivs.

In der Praxis ist FSDP segmentübergreifend am besten geeignet, wenn Sie den Batch pro Slice erhöhen, mehr Aktivierungen speichern, um die Re-Materialisierung während des Rückwärtsdurchlaufs zu minimieren, oder wenn Sie die Anzahl der Parameter in Ihrem neuronalen Netzwerk erhöhen.

Die Vorgänge „Alle Datenerfassung“ und „Alle Reduzierung“ in FSDP funktionieren ähnlich wie die in DP. Sie können also wie im vorherigen Abschnitt beschrieben feststellen, ob Ihre FSDP-Arbeitslast durch die DCN-Leistung eingeschränkt ist.

Wann sollte die Pipeline-Parallelität verwendet werden?

Die Pipeline-Parallelität wird relevant, wenn Sie mit anderen Parallelitätsstrategien, die eine globale Batchgröße erfordern, die Ihre bevorzugte maximale Batchgröße übersteigt, hohe Leistung erzielen. Durch die Pipeline-Parallelität können die Segmente, die eine Pipeline enthalten, einen Batch "teilen". Die Pipeline-Parallelität hat jedoch zwei wesentliche Nachteile:

  1. Dabei wird ein „Pipeline-Infofeld“ angezeigt, in dem Chips inaktiv sind, weil sie auf Daten warten.
  2. Es erfordert Mikro-Batching, was die effektive Batchgröße, die arithmetische Intensität und schließlich die FLOPS-Auslastung reduziert.

Die Pipeline-Parallelität sollte nur verwendet werden, wenn die anderen Parallelitätsstrategien eine zu große globale Batchgröße erfordern. Bevor Sie die Pipeline-Parallelität testen, sollten Sie empirisch feststellen, ob sich die Konvergenz pro Stichprobe bei der Batchgröße verlangsamt, die für ein leistungsstarkes FSDP erforderlich ist. FSDP erreicht tendenziell eine höhere FLOP-Auslastung des Modells. Wenn sich die Konvergenz pro Stichprobe jedoch mit zunehmender Batchgröße verlangsamt, ist die Pipeline-Parallelität möglicherweise immer noch die bessere Wahl. Die meisten Arbeitslasten können ausreichend große Batchgrößen tolerieren, um nicht von der Pipeline-Parallelität zu profitieren. Ihre Arbeitslast kann jedoch abweichen.

Wenn eine Pipeline-Parallelität erforderlich ist, empfehlen wir eine Kombination mit Datenparallelität oder FSDP. Auf diese Weise können Sie die Pipelinetiefe minimieren und gleichzeitig die Batchgröße pro Pipeline erhöhen, bis die DCN-Latenz einen Faktor für den Durchsatz abnimmt. Konkret sollten Sie bei N-Segmenten Pipelines der Tiefe 2 und N/2-Replikate der Datenparallelität, dann Pipelines der Tiefe 4 und N/4-Replikate der Datenparallelität usw. berücksichtigen, bis der Batch pro Pipeline so groß wird, dass die DCN-Sammlungen im Rückwärtsdurchgang hinter der Arithmetik verborgen werden können. Dadurch wird die durch die Pipeline-Parallelität verursachte Verlangsamung minimiert und Sie können über das globale Limit für die Batchgröße hinaus skalieren.

Best Practices für mehrere Segmente

Laden der Daten

Während des Trainings werden wiederholt Batches aus einem Dataset geladen, um sie in das Modell zu laden. Ein effizientes, asynchrones Datenladeprogramm, das den Batch auf mehrere Hosts aufteilt, ist wichtig, um zu vermeiden, dass die TPUs nicht überbelastet werden. Das aktuelle Datenladeprogramm in MaxText hat für jede Hostlast eine gleiche Teilmenge der Beispiele. Diese Lösung ist für Text geeignet, erfordert aber einen Reshard innerhalb des Modells. Außerdem bietet MaxText noch kein deterministisches Erstellen von Snapshots, mit dem der Daten-Iterator vor und nach dem vorzeitigen Beenden dieselben Daten laden kann.

Prüfpunktausführung

Die Orbax-Prüfpunktbibliothek bietet Primitive für die Prüfpunktausführung von JAX-PyTrees auf lokalen Speicher oder Google Cloud-Speicher. Wir bieten in checkpointing.py eine Referenzintegration mit einer synchronen Prüfpunktausführung in MaxText.

Unterstützte Konfigurationen

Formen

Alle Segmente müssen dieselbe Form haben (z. B. dieselbe AcceleratorType). Heterogene Segmentformen werden nicht unterstützt.

Orchestrierung

Die Orchestrierung wird mit GKE unterstützt. Weitere Informationen finden Sie unter TPUs in GKE.

Frameworks

Multislice unterstützt nur JAX- und PyTorch-Arbeitslasten.

Parallelität

Wir empfehlen Nutzern, „Multislice“ mit Datenparallelität zu testen. Weitere Informationen zur Implementierung der Pipeline-Parallelität mit Multislice erhalten Sie von Ihrem Google Cloud-Kundenbetreuer.

Support und Feedback

Wir freuen uns über jedes Feedback! Wenn Sie Feedback geben oder Support anfordern möchten, kontaktieren Sie uns über das Support- oder Feedbackformular für Cloud TPU.