Cloud TPU-Multislice – Übersicht

Cloud TPU Multislice ist eine Full-Stack-Technologie zur Leistungsskalierung. mit dem ein Trainingsjob mehrere TPU-Slices in einem einzelnen Pod oder auf in mehreren Pods mit einfacher Datenparallelität. Bei TPU v4-Chips bedeutet, dass Trainingsjobs mehr als 4.096 Chips in einem einzigen Durchlauf verwenden können. Für Schulungen bei denen weniger als 4.096 Chips benötigt werden, die Leistung. Mehrere kleinere Segmente sind jedoch leichter verfügbar, was eine schnellere Startzeit ermöglicht, wenn Multislice mit kleineren Segmente.

Mehrere Segmente skalieren die Leistung linear

Bei Bereitstellung in Multislice-Konfigurationen werden TPU-Chips in jedem Segment bereitgestellt. über Inter-Chip-Interconnect (ICI) kommunizieren können. TPU-Chips in verschiedenen -Slices kommunizieren, indem sie Daten an CPUs (Hosts) übertragen, die wiederum Daten über das Netzwerk des Rechenzentrums (DCN) übertragen.

Datenfluss mit mehreren Segmenten

Entwickler müssen keinen Code schreiben, um die DCN-Kommunikation zwischen den Segmenten zu implementieren. Der XLA-Compiler generiert diesen Code für dich und überschneidet die Kommunikation mit für maximale Leistung berechnet.

Konzepte

Beschleunigertyp
Die Form jedes TPU-Slice, das ein Multislice enthält. Jedes in einer Anfrage mit mehreren Segmenten denselben Beschleunigertyp hat. Beschleuniger besteht aus einem TPU-Typ (v4 oder v5e) gefolgt von der Anzahl TensorCores Beispielsweise gibt v4-128 eine TPU v4 mit 128 TensorCores an.
Automatische Reparatur
Wenn bei einem Slice ein Wartungsereignis, eine vorzeitige Beendigung oder ein Hardwarefehler auftritt, Cloud TPU erstellt ein neues Slice. Im seltenen Fall, dass nicht genügend Ressourcen zum Erstellen eines neuen Slice vorhanden sind, wird die Erstellung nicht abgeschlossen. bis Hardware verfügbar ist. Nachdem das neue Segment erstellt wurde, in der Multislice-Umgebung neu gestartet. Mit einem korrekt konfigurierten Startskript können ohne Eingriff des Nutzers automatisch neu gestartet werden. ab dem letzten Checkpoint.
Dataset
Die Daten, die von einem Modell für Training oder Inferenz verwendet werden.
Rechenzentrumsnetzwerke (DCN)
Ein Netzwerk mit höherer Latenz und geringem Durchsatz (im Vergleich zu ICI), das verbindet TPU-Slices in einer Multislice-Konfiguration.
Gruppenplanung
Wenn alle TPU-Slices gemeinsam bereitgestellt werden, Entweder wurden alle oder keine der Slices bereitgestellt.
Moderator:in
Ein Host ist ein physischer Computer, auf dem VMs ausgeführt werden. Ein Host kann höchstens vier VMs ausführen auf einmal. Jede VM hat eine dedizierte TPU.
Inferenz
Laden Sie ein vortrainiertes Modell für maschinelles Lernen auf einen Host und treffen Sie Vorhersagen für Daten.
Interchip Interconnect (ICI)
Interne Links mit hoher Geschwindigkeit und niedriger Latenz, die TPUs in einem TPU Pod verbinden.
Multislice
Zwei oder mehr TPU-Chip-Slices, die über DCN kommunizieren können.
Knoten
Im Multislice-Kontext bezieht sich ein Knoten auf ein einzelnes TPU-Slice. Jedes TPU-Slices in einem Multislice wird eine Knoten-ID zugewiesen.
Pod
Eine Sammlung von TPU-Chips, die über dedizierte ICI-Netzwerkschnittstellen verbunden sind. A Mit einem Pod können Sie die Verarbeitungslast auf mehrere TPUs verteilen.
Ressource in der Warteschlange (QR)
Eine Darstellung von TPU-Ressourcen, mit denen eine Anfrage für ein Element in die Warteschlange gestellt und verwaltet wird Einzel- oder Multislice-TPU-Umgebung.
Startskript
Ein standardmäßiges Compute Engine-Startskript die jedes Mal ausgeführt wird, wenn eine VM gestartet oder neu gestartet wird. Für „Multislice“ in der Anfrage zur Erstellung des QR-Codes angegeben. Weitere Informationen Informationen zu Cloud TPU-Startskripts finden Sie unter TPU-Ressourcen verwalten.
TPU-Slice
Ein logischer Unterabschnitt eines TPU-Pod, der aus TPU-Chips besteht. Alle Chips in einem Slice über das ICI-Netzwerk miteinander kommunizieren.
TPU-VM
Eine virtuelle Maschine, auf der Linux ausgeführt wird und die Zugriff auf die zugrunde liegenden TPUs hat. Für v4-TPUs hat jede TPU-VM direkten Zugriff auf vier Chips. Manchmal nennen wir eine TPU VM als Worker
Tensor
Eine Datenstruktur, die zur Darstellung mehrdimensionaler Daten in einer Maschine verwendet wird Lernmodells.
Tensor Processing Unit (TPU)
Der intern entwickelte Chip für die ML-Beschleunigung von Google. Sie wurden entwickelt, um schnelles und energieeffizientes Computing für wichtige ML-Aufgaben wie Matrixmultiplikation.
Arten der Cloud TPU-Kapazität

TPUs können aus verschiedenen Kapazitätstypen erstellt werden (siehe Nutzungsoptionen in TPU-Preise)

  • Reservierung: Ziel auf reserviertes Kontingent. Zur Verwendung reservierter Kontingente benötigen Sie Folgendes: Reservierungsvereinbarung mit Google. Verwenden Sie beim Erstellen das Flag --reserved. Ihre Ressourcen.
  • Spot: Targeting auf Kontingent auf Abruf mithilfe von Spot-VMs. Ihr werden unter Umständen vorzeitig beendet, um Platz für Anfragen für eine höhere priorisierten Job. Verwenden Sie beim Erstellen Ihrer Ressourcen das Flag --spot.
  • On-Demand: Gezielt auf ein On-Demand-Kontingent, das keine Reservierung erfordert und nicht vorzeitig beendet. Die TPU-Anfrage wird in eine On-Demand- von Cloud TPU angebotene Kontingentwarteschlange, ist die Verfügbarkeit von Ressourcen garantiert. Standardmäßig ausgewählt, keine Flags erforderlich.

Mehr erfahren

Wenn Sie noch keine TPUs verwendet haben, installieren Sie zuerst die Google Cloud CLI. und richten Sie Ihre Cloud TPU-Umgebung ein. Zur Verwendung Multislice verwenden, müssen Ihre TPU-Ressourcen als Ressourcen in der Warteschlange verwaltet werden.

Wenn Sie bereits TPU v4-Nutzer sind und eine Reservierung haben, müssen Sie möglicherweise migrieren in ein neues Reservierungssystem übertragen. Weitere Informationen wenden Sie sich an Ihren Google Cloud-Kundenbetreuer.

Einleitendes Beispiel

In dieser Anleitung wird Code aus dem MaxText GitHub-Repository verwendet. MaxText ist ein leistungsstarkes, beliebig skalierbares Open-Source-Projekt, das sich umfassend getestet hat in Python und Jax geschrieben. MaxText wurde für ein effizientes Training auf Cloud TPU

Code in shardings.py soll Ihnen den Einstieg in das Experimentieren mit verschiedenen Parallelität Optionen. Zum Beispiel Datenparallelität, vollständig fragmentierte Datenparallelität (Full Sharded Data Parallelism, FSDP), und Tensor-Parallelität. Der Code wird vom Einzelsegment zu „Multislice“ skaliert Umgebungen.

ICI-Parallelität

ICI bezieht sich auf die Hochgeschwindigkeits-Verbindung, die die TPUs in einem einzigen Segment. Die ICI-Fragmentierung entspricht der Fragmentierung innerhalb eines Slice. shardings.py stellt drei ICI-Parallelitätsparameter bereit:

  • ici_data_parallelism
  • ici_fsdp_parallelism
  • ici_tensor_parallelism

Die Werte, die Sie für diese Parameter angeben, bestimmen die Anzahl der Shards für der einzelnen Parallelitätsmethoden.

Diese Eingaben müssen begrenzt werden, damit ici_data_parallelism * ici_fsdp_parallelism * ici_tensor_parallelism ist gleich der Anzahl der Chips im Segment.

Die folgende Tabelle zeigt Beispielnutzereingaben für die ICI-Parallelität für die vier In v4-8 verfügbare Chips:

ici_data_parallelism ici_fsdp_parallelism ici_tensor_parallelism
4-Wege-FSDP 1 4 1
4-Wege-Tensor-Parallelität 1 1 4
2-Wege-FSDP + 2-Wege-Tensor-Parallelität 1 2 2

Beachten Sie, dass ici_data_parallelism in den meisten Fällen auf 1 belassen sollte, da der Wert Das ICI-Netzwerk ist schnell genug, um FSDP fast immer gegenüber Datenparallelität zu bevorzugen.

In diesem Beispiel wird davon ausgegangen, dass Sie mit dem Ausführen von Code auf einem einzelnen TPU-Slice vertraut sind. Beispiel: Berechnung auf einer Cloud TPU VM mit JAX ausführen Dieses Beispiel zeigt, wie shardings.py für ein einzelnes Segment ausgeführt wird.

  1. Richten Sie die Umgebung ein:

    $ gcloud auth login
    $ gcloud config set project your-project-id
    $ gcloud config set compute/zone your-zone
    
  2. Erstellen Sie SSH-Schlüssel für gcloud. Wir empfehlen, das Passwort leer zu lassen (drücke und Eingabetaste, nachdem Sie den folgenden Befehl ausgeführt haben. Wenn Sie aufgefordert werden, Die Datei google_compute_engine ist bereits vorhanden. Ersetzen Sie die vorhandene Version.

    $ ssh-keygen -f ~/.ssh/google_compute_engine
    
  3. Stellen Sie die TPUs mit dem folgenden Befehl bereit:

    $ gcloud compute tpus queued-resources \
    create your-qr-id \
    --accelerator-type your-accelerator-type \
    --runtime-version tpu-ubuntu2204-base \
    --node-id qr-id \
    [--reserved |--spot]
    

    Beschreibung der Befehls-Flags

    your-qr-id
    Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
    runtime-version
    Die [Cloud TPU-Softwareversion](/tpu/docs/supported-tpu-configurations#tpu_software_versions).
    node-id
    Die ID der TPU-Ressourcen, die als Reaktion auf die QR-Anfrage.
    reserved
    Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
    spot
    Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.

    Die Google Cloud CLI unterstützt nicht alle Optionen zum Erstellen von QR-Codes, z. B. Tags. Weitere Informationen finden Sie unter Kurzantworten erstellen.

  4. Warten Sie, bis sich der QR-Code im Status ACTIVE befindet. Das bedeutet, dass sich die Worker-Knoten den Status READY haben. Nach Beginn der QR-Code-Bereitstellung kann es ein bis fünf bis die Größe des QR-Codes abgestimmt ist. Sie können den Status mit dem folgenden Befehl erstellen:

    $ gcloud compute tpus queued-resources \
      list --filter=your-qr-id
    
  5. Ein v4-8-Slice hat eine einzelne TPU-VM. Stellen Sie über SSH eine Verbindung zur TPU-VM her:

    $ gcloud compute tpus tpu-vm ssh your-qr-id
    
  6. Klonen Sie MaxText (einschließlich shardings.py) auf Ihre TPU-VM.

  7. Führen Sie zur Installation das Setup-Skript im MaxText-Repository-Verzeichnis aus JAX und andere Abhängigkeiten von Ihrem TPU-Slice Das Einrichtungsskript Minuten gelaufen.

    $ bash setup.sh
    
  8. Führen Sie den folgenden Befehl aus, um shardings.py auf Ihrem TPU-Slice auszuführen.

    $ python3 pedagogical_examples/shardings.py \
      --ici_fsdp_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    

    Sie können die Ergebnisse in den Logs sehen. Ihre TPUs sollten etwa 260 TFLOP erreichen pro Sekunde oder eine beeindruckende FLOP-Auslastung von über 90 %! In diesem Fall haben wir ungefähr den maximalen Batch ausgewählt, der in den hohen Wert der TPU passt Bandwidth Memory (HBM) (Bandbreitenspeicher).

  9. Sie können sich gerne andere Fragmentierungsstrategien ansehen. statt ICI zu verwenden. Sie könnten z. B. die folgende Kombination ausprobieren:

    $ python3 pedagogical_examples/shardings.py \
      --ici_tensor_parallelism 4 \
      --batch_size 131072 \
      --embedding_dimension 2048
    
  10. Löschen Sie den QR-Code und das TPU-Slice, wenn Sie fertig sind. Sie sollten diese Bereinigung ausführen Schritte aus der Umgebung, in der Sie das Slice eingerichtet haben (führen Sie zuerst exit aus, beenden Sie die SSH-Sitzung. Das Löschen dauert zwei bis fünf Minuten. und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

    $ gcloud compute tpus queued-resources
      delete your-qr-id --force (--async)
    

Multi-Slice-Fragmentierung mit DCN-Parallelität

Das Skript shardings.py verwendet drei Parameter, die die DCN-Parallelität festlegen. entspricht der Anzahl der Shards jedes Typs von Datenparallelität:

  • dcn_data_parallelism
  • dcn_fsdp_parallelism
  • dcn_tensor_parallelism

Die Werte dieser Parameter müssen eingeschränkt werden, dcn_data_parallelism * dcn_fsdp_parallelism * dcn_tensor_parallelism ist gleich Anzahl der Slices.

Verwenden Sie als Beispiel für zwei Segmente --dcn_data_parallelism = 2.

dcn_data_parallelism dcn_fsdp_parallelism dcn_tensor_parallelism Anzahl der Segmente
Zweiwege-Datenparallelität 2 1 1 2

dcn_tensor_parallelism sollte immer auf 1 gesetzt werden, da das DCN schlecht ist. für eine solche Fragmentierung geeignet. Bei typischen LLM-Arbeitslasten auf v4-Chips dcn_fsdp_parallelism sollte auch auf 1 festgelegt werden, damit dcn_data_parallelism sollte auf die Anzahl der Segmente festgelegt werden, anwendungsabhängig.

Wenn Sie die Anzahl der Segmente erhöhen (vorausgesetzt, Sie behalten die Segmentgröße und den Batch pro Slice-Konstante), erhöhen Sie die Datenparallelität.

shardings.py in einer Umgebung mit mehreren Segmenten ausführen

Sie können shardings.py in einer Umgebung mit mehreren Segmenten mit multihost_runner.py oder durch Ausführen von shardings.py auf jeder TPU-VM. Hier verwenden wir multihost_runner.py Die folgenden Schritte ähneln denen, Erste Schritte: Schnelltests mit mehreren Bereichen aus dem MaxText-Repository, außer hier führen wir shardings.py anstelle der ein komplexeres LLM in train.py erstellen.

Das multihost_runner.py-Tool ist für schnelle, wiederholte Tests optimiert TPUs wiederverwenden können. Da das Skript multihost_runner.py von langlebigen SSH-Verbindungen verwenden möchten, empfehlen wir diese Methode nicht für Jobs mit langer Ausführungszeit. Wenn Sie einen längeren Job ausführen möchten (z. B. Stunden oder Tage), empfehlen wir Ihnen, multihost_job.py verwenden.

In dieser Anleitung geben wir mit dem Begriff runner die Maschine an, auf der Sie führen Sie das Skript multihost_runner.py aus. Wir verwenden den Begriff Worker für TPU-VMs, aus denen Ihre Slices bestehen. Sie können multihost_runner.py auf einem lokalen oder eine beliebige Compute Engine-VM im selben Projekt wie Ihre Slices. Laufen multihost_runner.py für einen Worker wird nicht unterstützt.

multihost_runner.py stellt automatisch über SSH eine Verbindung zu TPU-Workern her.

In diesem Beispiel wird shardings.py für zwei v4-16-Slices ausgeführt, insgesamt vier VMs und 16 TPU-Chips Sie können das Beispiel so ändern, dass es auf mehr TPUs ausgeführt wird.

Umgebung einrichten

  1. MaxText im Runner klonen Maschine.

  2. Wechseln Sie zum Repository-Verzeichnis.

  3. Erstellen Sie SSH-Schlüssel für gcloud. Wir empfehlen, das Passwort leer zu lassen (drücken Sie und Eingabetaste, nachdem Sie den folgenden Befehl ausgeführt haben. Wenn Sie aufgefordert werden, Die Datei google_compute_engine ist bereits vorhanden. Wählen Sie aus, dass Sie Ihre vorhandenen Version.

      $ ssh-keygen -f ~/.ssh/google_compute_engine
      

  4. Fügen Sie eine Umgebungsvariable hinzu, um die Anzahl der TPU-Slices auf 2 festzulegen.

      $ export SLICE_COUNT=2
      

  5. Erstellen Sie mit queued-resources create eine Multislice-Umgebung.

    Der folgende Befehl zeigt, wie Sie eine v4-Multislice-TPU erstellen. Zur Verwendung v5e einen accelerator-type für v5e (z. B. v5litepod-16) und den Version 5e runtime-version (v2-alpha-tpuv5-lite).

      $ gcloud compute tpus queued-resources 
    create your-qr-id
    --accelerator-type=your-accelerator-type
    --runtime-version=tpu-vm-runtime-version
    --node-count=node-count
    --node-prefix=your-qr-id
    [--reserved|--spot]

    Beschreibung der Befehls-Flags

    your-qr-id
    Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
    runtime-version
    Die Cloud TPU-Softwareversion.
    node-count
    Die Anzahl der zu erstellenden Segmente.
    node-prefix
    Das Präfix, mit dem Namen für die einzelnen Segmente generiert werden. Eine Zahl wird angehängt bis zum Präfix hinzu. Beispiel: Sie legen node-prefix fest. auf mySlice gesetzt haben, haben die Segmente den Namen: mySlice-0, mySlice-1, wobei für jedes Segment numerisch fortgesetzt wird.
    reserved
    Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
    spot
    Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.

  6. Wenn die QR-Bereitstellung gestartet wird, kann es bis zu fünf Minuten dauern, bis die Bereitstellung abgeschlossen ist je nach Größe des QR-Codes. Warten Sie, bis sich die Ressource in der Warteschlange (QR) in der Warteschlange befindet ACTIVE. Sie können den Status einer QR-Anfrage mit der folgenden Befehl:

    $ gcloud compute tpus queued-resources list \
    --filter=your-qr-id
    

    Die Ausgabe sollte in etwa so aussehen:

    NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
    ...
    que-res-id  us-central2-b  4           v4-16             ACTIVE
    ...
    

    Wenden Sie sich an Ihren Google Cloud-Kundenbetreuer, wenn der QR-Status in der WAITING_FOR_RESOURCES- oder PROVISIONING-Status für mehr als 15 Minuten.

  7. Installieren Sie die Abhängigkeiten:

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="bash setup.sh"
    
  8. Führen Sie shardings.py mit multihost_runner.py auf jedem Worker aus.

    $ python3 multihost_runner.py \
      --TPU_PREFIX=your-qr-id \
      --COMMAND="python3 pedagogical_examples/shardings.py \
      --dcn_data_parallelism $SLICE_COUNT \
      --ici_fsdp_parallelism 8 \
      --batch_size 131072 \
      --embedding_dimension 2048"
    

    Im Protokoll werden ungefähr 230 TFLOPs pro Sekunde der Leistung angezeigt. -Dateien.

  9. Bereinigen Sie die TPUs und den QR-Code, wenn Sie fertig sind. Das Löschen dauert zwei bis fünf Minuten abgeschlossen und kann im Hintergrund ausgeführt werden. Flag --async.

Arbeitslast auf Multislice skalieren

Bevor Sie Ihr Modell in einer Umgebung mit mehreren Segmenten ausführen, erstellen Sie folgenden Codeänderungen:

Dies sollten die einzigen erforderlichen Codeänderungen sein, die beim Wechsel zur Multislice-Funktion erforderlich sind. Um eine hohe Leistung zu erzielen, muss das DCN parallel und vollständig den Daten zugeordnet werden. parallele Achsen fragmentierter Daten oder parallele Pipelineachsen. Leistungsaspekte und Fragmentierungsstrategien werden ausführlicher Fragmentierung mit Multislice für maximale Leistung.

Sie können bestätigen, dass Ihr Code auf alle Geräte zugreifen kann, indem Sie bestätigen, len(jax.devices()) entspricht der Anzahl der Chips im Multislice zu verbessern. Wenn Sie beispielsweise vier Segmente von v4-16 verwenden, haben Sie acht Chips pro Segment × 4 Slices, sodass len(jax.devices()) 32 zurückgeben sollte.

Segmentgrößen für Multislice-Umgebungen auswählen

Fügen Sie für eine lineare Beschleunigung neue Segmente hinzu, die dieselbe Größe wie die Segment. Wenn Sie beispielsweise ein v4-512-Slice verwenden, Eine etwa doppelt so hohe Leistung erzielen, indem ein zweites v4-512-Segment hinzugefügt wird und die globale Batchgröße zu verdoppeln. Weitere Informationen finden Sie unter Fragmentierung mit Multislice für maximale Leistung.

Job in mehreren Segmenten ausführen

Es gibt drei verschiedene Ansätze, um Ihre benutzerdefinierte Arbeitslast in einer Umgebung mit mehreren Segmenten:

  1. Mit dem Skript multihost_runner.py für die Testausführung
  2. Mit dem Produktions-Runner-Skript multihost_job.py
  3. Manueller Ansatz

Script für die Testausführung

Die multihost_runner.py Das Skript verteilt den Code an eine vorhandene Multislice-Umgebung und führt Ihren Befehl auf jedem Host ausführen, Ihre Protokolle zurückkopieren und die Fehler der einzelnen Befehle nachverfolgen. Status. Das Skript multihost_runner.py ist dokumentiert in README-Datei für MaxText.

Da multihost_runner.py persistente SSH-Verbindungen unterhält, ist es nur eignet sich für relativ kurze Tests. Sie können Passen Sie die Schritte in der multihost_runner.py-Anleitung an. an Ihre Arbeitslast und Hardwarekonfiguration anpassen.

Skript für die Produktionsausführung

Für Produktionsjobs, die Resilienz gegen Hardwarefehler und andere vorzeitigem Beenden zu deaktivieren, empfiehlt es sich, eine direkte Einbindung in die Ressource vom Typ „Create Queued Resource“ der API erstellen. Als Beispiel stellen wir multihost_job.py bereit, löst den API-Aufruf Created Queued Resource mit entsprechendem Start aus um das Training auszuführen und bei vorzeitigem Beenden fortzusetzen. Das multihost_job.py Skript im README-Datei für MaxText.

Da multihost_job.py für jede Ausführung Ressourcen bereitstellen muss, wird dies nicht einen so schnellen Iterationszyklus wie multihost_runner.py bereitstellen.

Manueller Ansatz

Wir empfehlen, multihost_runner.py zu verwenden oder anzupassen oder multihost_job.py, um Ihre benutzerdefinierte Arbeitslast in Ihre Multislice-Konfiguration. Wenn Sie jedoch eine Bereitstellung und können Sie Ihre Umgebung direkt über QR-Befehle verwalten, siehe Multislice-Umgebung verwalten

Multislice-Umgebung verwalten

So können Sie QR-Codes manuell bereitstellen und verwalten, ohne die Tools zu verwenden im MaxText-Repository bereitgestellt haben, lesen Sie die folgenden Abschnitten.

Kurzantworten erstellen

Legen Sie die folgenden Umgebungsvariablen fest, bevor Sie Kapazität bereitstellen:

  $ export your-qr-id=your-queued-resource-id
  $ export PROJECT=your-project-name
  $ export ZONE=us-central2-b
  $ export NETWORK_NAME=your-network-name
  $ export SUBNETWORK_NAME=your-subnetwork-name
  $ export RUNTIME_VERSION=tpu-ubuntu2204-base
  $ export ACCELERATOR_TYPE=v4-16
  $ export SLICE_COUNT=4
  $ export STARTUP_SCRIPT="#!/bin/bash\n ..."
  $ gcloud config set project project-name
  $ gcloud config set compute/zone zone
Eingabe Beschreibung
your-qr-id Die vom Nutzer zugewiesene ID des QR-Codes.
PROJECT Name des Google Cloud-Projekts
ZONE us-central2-b
NETWORK_NAME Name der VPC-Netzwerke.
SUBNETWORK_NAME Name des Subnetzes in VPC-Netzwerken
RUNTIME_VERSION tpu-Ubuntu2204-base
ACCELERATOR_TYPE v4-16
EXAMPLE_TAG_1, EXAMPLE_TAG_2... Tags zum Identifizieren gültiger Quellen oder Ziele für Netzwerkfirewalls
SLICE_COUNT Anzahl der Segmente. Begrenzt auf maximal 256 Segmente.
STARTUP_SCRIPT Falls zur Erstellungsanfrage hinzugefügt: Ein Startskript kann immer dann ausgeführt werden, wenn ein TPU-Slice bereitgestellt oder neu gestartet wird. und ob das TPU-Slice repariert oder zurückgesetzt wurde.

QR-Anfrage mit gcloud erstellen

$ gcloud compute tpus queued-resources \
  create ${your-qr-id} \
  --project your-project-id \
  --zone your-zone \
  --node-count ${SLICE_COUNT} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --network ${NETWORK_NAME} \
  --subnetwork ${SUBNETWORK_NAME} \
  --tags ${EXAMPLE_TAG_1},${EXAMPLE_TAG_2} \ --metadata=startup-script='${STARTUP_SCRIPT}'
  [--reserved|--spot]
  

Beschreibung der Befehls-Flags

your-qr-id
Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
project
Ein benutzerdefinierter String, der die QR-Anfrage identifiziert.
zone
Die Google Cloud-Zone, in der der QR-Code erstellt werden soll.
node-count
Die Anzahl der zu erstellenden Segmente.
accelerator-type
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
runtime-version
Die Cloud TPU-Softwareversion.
network
Der Name eines VPC-Netzwerk, an das die TPU-Ressource angehängt werden soll.
subnetwork
Der Name eines VPC-Subnetzwerks, an das die TPU-Ressource angehängt werden soll.
reserved
Verwenden Sie beim Erstellen der Slices ein reserviertes Kontingent.
spot
Verwenden Sie beim Erstellen der Slices das Kontingent für Spot-VMs.

Prüfen Sie, ob Sie das entsprechende Kontingent haben, bevor Sie --reserved auswählen. --spot oder das standardmäßige On-Demand-Kontingent. Informationen zu Kontingenttypen Siehe Kontingentrichtlinie.

QR-Anfrage mit curl erstellen

Erstellen Sie eine Datei mit dem Namen queued-resource-req.json und kopieren Sie den folgenden JSON-Code hinein.

{
  "guaranteed": { "reserved": true },
  "tpu": {
    "node_spec": [
    {
      "parent": "projects/your-project-number/locations/your-zone",
        "node": {
          "accelerator_type": "accelerator-type",
          "runtime_version": "tpu-vm-runtime-version",
          "network_config": {
            "network": "your-network-name",
            "subnetwork": "your-subnetwork-name",
            "enable_external_ips": true
          },
          "tags" : ["example-tag-1"]
          "metadata": {
            "startup-script": "your-startup-script"
          }
      },
      "multi_node_params": {
        "node_count": slice-count,
        "node_id_prefix": "your-queued-resource-id"
      }
    }
    ]
  }
}
  • your-project-number – Ihre Google Cloud-Projektnummer
  • your-zone – Der Bereich, in dem Sie den QR-Code erstellen möchten
  • accelerator-type – Version und Größe eines einzelnen Slice
  • tpu-vm-runtime-version: Die TPU-VM-Laufzeitversionen
  • your-network-name – Optional ein Netzwerk, an das der QR-Code angehängt wird
  • your-subnetwork-name: Optional ein Subnetzwerk, an das der QR-Code angehängt wird
  • example-tag-1 (optional) ein beliebiger Tag-String
  • your-startup-script – ein Startskript, das ausgeführt wird, wenn der QR-Code zugewiesen wird
  • slice-count – Die Anzahl der TPU-Slices in Ihrer Multislice-Umgebung
  • your-qr-id – Die vom Nutzer angegebene ID für den QR-Code

Weitere Informationen finden Sie unter REST Queued Resource API Dokumentation zu allen verfügbaren Optionen.

Ersetzen Sie Folgendes, um die Spot-Kapazität zu verwenden:

"guaranteed": { "reserved": true } mit "spot": {}

Entfernen Sie die Zeile, um die standardmäßige On-Demand-Kapazität zu verwenden.

Senden Sie die Anfrage zum Erstellen des QR-Codes mit der JSON-Nutzlast:

  $ curl -X POST -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" -d @queuedresourcereq.json https://tpu.googleapis.com/v2alpha1/projects/your-project-id/locations/your-zone/queuedResources\?queued_resource_id\=your-qr-id
  • your-project-id – Ihre Google Cloud-Projekt-ID
  • your-zone – Der Bereich, in dem Sie den QR-Code erstellen möchten
  • your-qr-id – Die vom Nutzer angegebene ID für den QR-Code

Die Antwort sollte in etwa so aussehen:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qr-guid>",
  "metadata": {
    "@type": "type.googleapis.com/google.cloud.common.OperationMetadata",
    "createTime": "2023-11-01T00:17:05.742546311Z",
    "target": "projects/<your-project-id>/locations/<your-zone>/queuedResources/<your-qa-id>",
    "verb": "create",
    "cancelRequested": false,
    "apiVersion": "v2alpha1"
  },
  "done": false
}

Verwenden Sie den GUID-Wert am Ende des Stringwerts für das Attribut name, um Informationen zur QR-Anfrage.

Status eines QR-Codes abrufen

Verwenden Sie den folgenden Befehl, um den Status der QR-Anfrage abzurufen:

  $ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/operations/operation-your-qr-guid
  • your-project-id – Ihre Google Cloud-Projekt-ID
  • your-zone – Der Bereich, in dem der QR-Code erstellt werden soll.
  • your-qr-guid – die GUID nach name in der Ausgabe des Anfrage zur Erstellung eines QR-Codes.

Die Antwort auf diesen Befehl enthält den Status des Vorgangs:

{
  "name": "projects/<your-project-id>/locations/<your-zone>/operations/operation-<your-qa-guid>,
  "metadata": {...},
  "done": true,
  "response": {
    "@type": "type.googleapis.com/google.cloud.tpu.v2.QueuedResource",
    ...
    "state": {
      "state": "WAITING_FOR_RESOURCES"
    }
  }
}

Wenn der QR-Code erfolgreich erstellt wurde, ("done = true"), wird der Status innerhalb der Das Feld response ist entweder WAITING_FOR_RESOURCES oder FAILED. Wenn der QR-Code den Status WAITING_FOR_RESOURCES hat, wurde er und mit der Bereitstellung beginnen, sobald genügend Ressourcen vorhanden sind. Wenn der QR-Code den Status FAILED hat, wird die Fehlerursache in der Ausgabe angegeben. Weitere Informationen Informationen zu anderen möglichen Status finden Sie in der Nutzerhandbuch für Ressourcen in der Warteschlange.

Sobald der Vorgang abgeschlossen ist, verwenden Sie die Kurzantworten beschreiben. um die Phasen des QR-Codes zu überwachen.

In seltenen Fällen kann es vorkommen, dass Ihr QR-Code den Status FAILED hat, während einige Slices ACTIVE. Löschen Sie in diesem Fall die erstellten Ressourcen und versuche es in ein paar Minuten noch einmal oder wende dich an uns an das Cloud TPU-Team senden, um das Problem zu beheben.

SSH-Verbindung herstellen und Abhängigkeiten installieren

JAX-Code auf TPU-Pod-Slices ausführen wird beschrieben, wie Sie in einem einzelnen Slice mithilfe von SSH eine Verbindung zu Ihren TPU-VMs herstellen. Bis eine Verbindung zu allen TPU-VMs in Ihrer Multislice-Umgebung über SSH und um Abhängigkeiten zu installieren, verwenden Sie den folgenden gcloud-Befehl:

  $ gcloud compute tpus queued-resources ssh ${your-qr-id} \
    --zone your-zone \
    --node=all \
    --worker=all \
    --command="command-to-run"
    --batch-size=4

Mit diesem gcloud-Befehl wird der angegebene Befehl an alle Worker und Knoten in QR-Code mit SSH Der Befehl wird in Vierergruppen zusammengefasst und gleichzeitig. Der nächste Batch von Befehlen wird gesendet, wenn der aktuelle Batch um die Ausführung abzuschließen. Tritt bei einem der Befehle ein Fehler auf, wird die Verarbeitung und es werden keine weiteren Batches gesendet. Weitere Informationen finden Sie in der API-Referenz für Ressourcen in der Warteschlange Wenn die Anzahl der verwendeten Slices das Threading des lokalen Computers überschreitet (auch Batch-Limit genannt), besteht ein Deadlock. Hier ein Beispiel: davon ausgehen, dass das Batchlimit auf Ihrem lokalen Rechner bei 64 liegt. Wenn Sie versuchen, auf mehr als 64 Slices, z. B. 100, hat der SSH-Befehl in Batches aufteilen. Das Trainingsskript wird beim ersten Batch von 64 und warten Sie, bis die Skripte abgeschlossen sind, bevor Sie sie auf der verbleibenden Batch von 36 Segmenten. Der erste Batch mit 64 Slices kann jedoch nicht bis die verbleibenden 36 Slices beginnen, das Skript auszuführen, was zu einer Deadlock.

Um dies zu verhindern, können Sie das Trainingsskript im Hintergrund auf jede VM durch Anhängen eines kaufmännischen Und-Zeichens (&) an den von Ihnen angegebenen Skriptbefehl mit dem Flag --command. Wenn Sie dies tun, nachdem Sie das Trainingsskript gestartet haben, auf den ersten Batch von Segmenten, kehrt die Steuerung sofort wieder auf den SSH-Befehl. Der SSH-Befehl kann dann das Trainingsskript auf der restlichen 36 Segmente. Sie müssen stdout und stderr streamt sie ordnungsgemäß, wenn sie im Hintergrund ausgeführt wird. Zum Erhöhen Parallelität innerhalb desselben QR-Codes können Sie bestimmte Segmente über die --node .

Netzwerkeinrichtung

Führen Sie die folgenden Schritte aus, damit TPU-Slices miteinander kommunizieren können. Installieren Sie JAX auf jedem Segment. Weitere Informationen finden Sie unter JAX-Code auf TPU-Pod-Slices ausführen Behaupten, dass len(jax.devices()) entspricht der Anzahl der Chips im Multislice zu verbessern. Führen Sie dazu für jedes Segment folgenden Befehl aus:

  $ python3 -c 'import jax; print(jax.devices())'

Wenn Sie diesen Code auf vier Abschnitten von v4-16 ausführen, gibt es acht Chips pro Slice und vier Slices werden insgesamt 32 Chips (Geräte) zurückgegeben. von jax.devices().

Kurzantworten auflisten

Sie können den Status Ihrer QRs mit dem Befehl queued-resources list abrufen:

$ gcloud compute tpus queued-resources list

NAME        ZONE           NODE_COUNT  ACCELERATOR_TYPE  STATE
...
que-res-id  us-central2-b  4           v4-16             ACTIVE
...

Kurzantworten beschreiben

Um die detaillierte Konfiguration und den Status eines QR-Codes anzuzeigen, verwende die um die QR-API zu beschreiben. Sie können diese API mit gcloud oder curl aufrufen.

mit gcloud:

$ gcloud compute tpus queued-resources describe ${your-qr-id}
...state:
 state: ACTIVE
...

mit curl:

$ curl -X GET -H "Authorization: Bearer $(gcloud auth print-access-token)" -H "Content-Type: application/json" https://tpu.googleapis.com/v2/projects/your-project-id/locations/your-zone/queuedResources/${your-qr-id}
{
  "name": your-queued-res,
  "tpu": {
    "nodeSpec": [
      {
        ... // node 1
      },
      {
        ... // node 2
      },
      ...
    ]
  },
  ...
  "state": "ACTIVE"
}

state steht für den Status eines QR-Codes. Weitere Informationen zu den möglichen Informationen zum Status von QR-Codes finden Sie unter Ressourcen in der Warteschlange.

Job in einer bereitgestellten Umgebung starten

Sie können Arbeitslasten manuell ausführen, indem Sie über SSH eine Verbindung zu allen Hosts in jedem Segment herstellen und führen den folgenden Befehl auf allen Hosts aus.

$ gcloud compute tpus tpu-vm ssh your-qr-id \
  --zone=your-zone \
  --worker=all \
  --node=all \
  --command="command-to-run"

Kurzantworten zurücksetzen

Die ResetQueuedResource API kann zum Zurücksetzen verwendet werden alle VMs in einem ACTIVE-QR-Code. Durch das Zurücksetzen der VMs wird das Löschen des Arbeitsspeichers und die VM in ihren Ausgangszustand zurücksetzt. Alle lokal gespeicherten Daten bleiben intakt und das Startskript wird nach dem Zurücksetzen aufgerufen. Die Die ResetQueuedResource API kann nützlich sein, wenn Sie alle TPUs neu starten möchten. Für z. B. wenn das Training hängen bleibt und das Zurücksetzen aller VMs einfacher ist als das Debugging.

Das Zurücksetzen aller VMs erfolgt parallel und es wird ein ResetQueuedResource dauert der Vorgang ein bis zwei Minuten. Rufen Sie die API mit folgendem Befehl auf: Befehl:

$ gcloud compute tpus queued-resources reset your-qr-id

QR-Codes werden gelöscht

Um am Ende der Trainingssitzung Ressourcen freizugeben, löschen Sie die in der Warteschlange Ressource mit dem Flag --force. Der Löschvorgang dauert zwei bis fünf Minuten, abgeschlossen und kann mit dem optionalen Flag --async im Hintergrund ausgeführt werden.

$ gcloud compute tpus queued-resources \
delete your-qr-id --force (--async)

Automatische Wiederherstellung nach Fehlern

Im Falle einer Unterbrechung bietet Multislice keine Eingriffe. Reparatur des betroffenen Slice und anschließendes Zurücksetzen aller Segmente. Betroffene Slice wird durch ein neues Slice und die restlichen ansonsten fehlerfreien Slices ersetzt. wurden zurückgesetzt. Wenn keine Kapazität zum Zuweisen eines Ersatz-Slice verfügbar, wird das Training beendet.

Wenn Sie das Training nach einer Unterbrechung automatisch fortsetzen möchten, müssen Sie eine Startskript, das auf und lädt die letzten gespeicherten Prüfpunkte. Ihr Startskript wird automatisch ausgeführt jedes Mal, wenn ein Segment neu zugewiesen oder eine VM zurückgesetzt wird. Sie geben ein Start-up an in der JSON-Nutzlast, die Sie an die API zum Erstellen der QR-Anfrage senden.

Das folgende Startskript (wird in Kurzantworten erstellen verwendet) können Sie sich automatisch nach Fehlern erholen und das Training ab dem Prüfpunkte, die während des MaxText-Trainings in einem Cloud Storage-Bucket gespeichert wurden:

{
 "tpu": {
   "node_spec": [
     {
      ...
         "metadata": {
               "startup-script": "#! /bin/bash \n pwd \n runuser -l user1 -c 'cd /home/user1/MaxText && python3 MaxText/train.py MaxText/configs/base.yml run_name=run_test_failure_recovery dcn_data_parallelism=4 ici_fsdp_parallelism=8 steps=10000 save_period=10 base_output_directory='gs://user1-us-central2'' EOF"
         }
     ...
     }
   ]
 }
}

Klonen Sie das MaxText-Repository, bevor Sie dies versuchen aus.

Profilerstellung und Fehlerbehebung

Die Profilerstellung ist in Einzel-Slice- und Multi-Slice-Umgebungen identisch. Für Weitere Informationen finden Sie unter Profilerstellung für JAX-Programme.

Optimierte Schulungen

Fragmentierung mit Multislice für maximale Leistung

Um in Multislice-Umgebungen maximale Leistung zu erzielen, über mehrere Segmente aufteilen. Normalerweise gibt es drei Auswahlmöglichkeiten (Datenparallelität, vollständig fragmentierte Datenparallelität und Pipeline-Parallelität). Wir raten davon ab, Aktivierungen auf die Modelldimensionen zu verteilen (manchmal Tensor-Parallelität), da dafür zu viel Bandbreite zwischen den Segmenten benötigt wird. Für alle diese Strategien kann dieselbe Fragmentierungsstrategie für ein Segment beibehalten werden. die sich bei Ihnen bewährt hat.

Wir empfehlen, mit reiner Datenparallelität zu beginnen. Vollständig fragmentierte Daten verwenden Parallelität ist nützlich, um Arbeitsspeichernutzung freizugeben. Der Nachteil ist, dass Die Kommunikation zwischen den Segmenten nutzt das DCN-Netzwerk und verlangsamt Arbeitsbelastung. Pipeline-Parallelität nur verwenden, wenn dies basierend auf der Batchgröße erforderlich ist (wie unten analysiert).

Wann sollte Datenparallelität verwendet werden?

Die reine Datenparallelität funktioniert gut, wenn Sie eine Arbeitslast haben, aber Sie möchten die Leistung verbessern, indem Sie in mehreren Segmenten.

Um eine starke Skalierung über mehrere Segmente hinweg zu erzielen, für die vollständige Reduzierung im DCN darf der erforderliche Zeitraum nicht unterschreiten. für einen Rückwärtsdurchlauf. DCN wird für die Kommunikation zwischen Slices und ein begrenzender Faktor für den Arbeitslastdurchsatz.

Jeder v4-TPU-Chip hat eine maximale Leistung von 275 × 1012 FLOPS pro Sekunde.

Es gibt vier Chips pro TPU-Host und jeder Host hat eine maximale Netzwerkbandbreite von 50 Gbit/s.

Das bedeutet die arithmetische Intensität ist 4 × 275 × 1012 FLOPS ÷ 50 Gbit / s = 22.000 FLOPS / Bit.

Ihr Modell verwendet 32 bis 64 Bit DCN-Bandbreite für jeden Parameter und Schritt. Wenn Sie zwei Segmente verwenden, benötigt Ihr Modell 32 Bit DCN-Bandbreite. Wenn Sie Wenn Sie mehr als zwei Slices verwenden, führt der Compiler einen vollständigen Shuffle All-Reduce durch. und Sie verwenden bis zu 64 Bit DCN-Bandbreite für jeden Parameter pro Schritt. Wie viele FLOPS für jeden Parameter erforderlich sind, hängt von Modell. Insbesondere bei Transformer-basierten Sprachmodellen ist die Anzahl der FLOPS Erforderlich für eine Vorwärts- und Rückwärtsterminierung sind ungefähr 6 * B * P, wobei Folgendes gilt:

  • B ist die Batchgröße in Tokens.
  • P ist die Anzahl der Parameter.

Die Anzahl der FLOPS pro Parameter ist 6 * B und die Anzahl der FLOPS pro Parameter beim Rückwärtstermin ist 4 * B.

Um eine starke Skalierung in mehreren Segmenten zu gewährleisten, muss der operative Intensität überschreitet die arithmetische Intensität der TPU-Hardware. Um die Betriebsintensität, dividieren Sie die Anzahl der FLOPS pro Parameter während der Netzwerkbandbreite (in Bit) pro Parameter und Schritt rückwärts: Operational Intensity = FLOPSbackwards_pass / DCN bandwidth

Wenn Sie also bei einem Transformer-basierten Sprachmodell zwei Slices verwenden, Operational intensity = 4 * B / 32

Wenn Sie mehr als zwei Segmente verwenden: Operational intensity = 4 * B/64

Dies deutet auf eine Mindest-Batchgröße zwischen 176.000 und 352.000 für Transformer hin Sprachmodellen. Da das DCN Pakete kurzzeitig löschen kann, um eine signifikante Fehlerspanne beizubehalten und nur Datenparallelität bereitzustellen. wenn die Batchgröße pro Pod mindestens 350.000 (zwei Pods) bis 700.000 (viele Pods) beträgt.

Bei anderen Modellarchitekturen müssen Sie die Laufzeit Rückwärtsdurchgang pro Slice (entweder durch die zeitliche Abfolge mit einem Profiler oder durch Zählen) FLOPS). Dann können Sie dies mit der erwarteten Laufzeit vergleichen, und eine gute Einschätzung dessen, ob eine Datenparallelität für Sie sinnvoll ist.

Wann sollte die vollständig fragmentierte Datenparallelität (FSDP) verwendet werden?

Die vollständig fragmentierte Datenparallelität (Fully Sharded Data Parallelism (FSDP)) kombiniert die Datenparallelität (Fragmentierung der Daten über Knoten hinweg), wobei die Gewichtungen auf Knoten aufgeteilt werden. Für jeden Vorgang in Beim Vor- und Zurückgehen sind die Gewichte eingesammelt, sodass jedes die benötigten Gewichte hat. Anstatt die Farbverläufe mithilfe von werden die Gradienten beim Entstehen reduziert. Auf diese Weise erhält jedes Segment nur die Gradienten für die Gewichtungen, für die es verantwortlich ist.

Ähnlich wie bei der Datenparallelität erfordert FSDP eine Skalierung der globalen Batchgröße linear mit der Anzahl der Segmente. FSDP verringert die Speicherauslastung, erhöhen Sie die Anzahl der Segmente. Das liegt daran, dass die Anzahl der Gewichtungen Optimierstatus pro Segment nimmt ab, aber es geschieht zum Preis und erhöht die Wahrscheinlichkeit einer Blockierung aufgrund kollektiv.

In der Praxis ist FSDP segmentübergreifend am besten geeignet, wenn Sie den Batch pro und mehr Aktivierungen speichern, um die Re-Materialisierung während der die Anzahl der Parameter in Ihrem neuronalen Netzwerk rückwärts durchläuft.

Die Operationen „All-Database“ und „All-Reduce“ funktionieren beim FSDP ähnlich wie bei DP, So können Sie feststellen, ob Ihre FSDP-Arbeitslast durch die DCN-Leistung in wie im vorherigen Abschnitt beschrieben.

Wann sollte die Pipeline-Parallelität verwendet werden?

Die Pipeline-Parallelität wird relevant, wenn eine hohe Leistung mit anderen Parallelitätsstrategien, die eine globale Batchgröße erfordern, die größer als bevorzugte maximale Batchgröße. Dank der Pipeline-Parallelität mit einer Pipeline zur Freigabe einen Batch. Die Pipeline-Parallelität hat jedoch zwei wesentliche Nachteile:

  1. Dabei wird das Pipeline-Infofeld angezeigt. Chips sind inaktiv, weil sie warten nach Daten.
  2. Es erfordert Mikro-Batching, wodurch sich die effektive Batchgröße verringert, die arithmetische Intensität und schließlich die FLOP-Auslastung.

Die Pipeline-Parallelität sollte nur verwendet werden, wenn die anderen Parallelitätsstrategien eine zu große globale Batchgröße erfordern. Bevor Sie die Pipeline-Parallelität testen, Es lohnt sich, empirisch zu experimentieren, ob die Konvergenz pro Stichprobe die für ein leistungsstarkes FSDP erforderlich ist. FSDP erreicht tendenziell höhere FLOP-Auslastung des Modells, aber wenn die Konvergenz pro Stichprobe sich verlangsamt, nimmt die Pipeline-Parallelität zu. Meiste können ausreichend große Batchgrößen toleriert werden, Pipeline-Parallelität. Ihre Arbeitslast kann jedoch abweichen.

Wenn eine Pipeline-Parallelität erforderlich ist, empfehlen wir, sie mit Daten zu kombinieren. oder FSDP. So können Sie die Pipelinetiefe minimieren, während dass die Batchgröße pro Pipeline erhöht wird, bis die DCN-Latenz abnimmt. Durchsatz berücksichtigt wird. Konkret sollten Sie bei n Slices Pipelines mit Tiefe 2- und N/2-Replikate der Datenparallelität, dann Pipelines der Tiefe 4 und N/4 Replikate der Datenparallelität usw., bis der Batch pro Pipeline groß wird dass die DCN-Sammlungen hinter der Arithmetik in der rückwärts durch. Dadurch wird die durch die Pipeline verursachte Verlangsamung minimiert. und gleichzeitig eine Skalierung über das globale Limit für Batchgrößen hinaus ermöglicht.

Best Practices für mehrere Segmente

Laden der Daten

Während des Trainings werden wiederholt Batches aus einem Dataset geladen, um Daten in den Modell. Ein effizientes asynchrones Datenladeprogramm, das den Batch auf ist wichtig, damit die TPUs nicht ausgebremst werden. Das aktuelle Datenladeprogramm in MaxText hat für jede Hostlast eine gleiche Teilmenge der Beispiele. Diese Lösung ist ist für Text geeignet, erfordert aber ein Reshard innerhalb des Modells. Zusätzlich kann MaxText bietet noch keine deterministische Snapshot-Erstellung, die es dem Daten-Iterator ermöglichen würde, um vor und nach dem vorzeitigen Beenden die gleichen Daten zu laden.

Prüfpunktausführung

Die Prüfpunktbibliothek Orbax bietet Primitiven für die Prüfpunktausführung von JAX-PyTrees auf lokalen Speicher oder Google Cloud-Speicher. Wir bieten eine Referenzintegration mit synchroner Prüfpunktausführung in MaxText. in „checkpointing.py“.

Unterstützte Konfigurationen

Formen

Alle Segmente müssen die gleiche Form haben (z. B. dieselbe AcceleratorType). Heterogene Scheibenformen werden nicht unterstützt.

Orchestrierung

Die Orchestrierung wird mit GKE unterstützt. Weitere Informationen Siehe TPUs in GKE.

Frameworks

Multislice unterstützt nur JAX- und PyTorch-Arbeitslasten.

Parallelismus

Wir empfehlen Nutzern, „Multislice“ mit Datenparallelität zu testen. Weitere Informationen zur Implementierung der Pipeline-Parallelität mit Multislice erfahren Sie von Ihrem Google Cloud-Kundenbetreuer

Support und Feedback

Wir freuen uns über jedes Feedback! Wenn du Feedback geben oder Unterstützung anfordern möchtest, wende dich an uns mithilfe des Support- oder Feedbackformulars für Cloud TPU.