Trillium (v6e) – Einführung

In dieser Dokumentation, in der TPU API und in den Protokollen wird v6e für Trillium verwendet. v6e steht für die sechste Generation von TPU von Google.

Mit 256 Chips pro Pod hat v6e viele Ähnlichkeiten mit v5e. Dieses System ist für das Training, die Feinabstimmung und das Bereitstellen von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks) optimiert.

v6e-Systemarchitektur

Informationen zur Cloud TPU-Konfiguration finden Sie in der v6e-Dokumentation.

In diesem Dokument liegt der Schwerpunkt auf der Einrichtung des Modelltrainings mit den Frameworks JAX, PyTorch oder TensorFlow. Mit jedem Framework können Sie TPUs mithilfe von Ressourcen in der Warteschlange oder der Google Kubernetes Engine (GKE) bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.

Google Cloud-Projekt vorbereiten

  1. Melden Sie sich in Ihrem Google-Konto an. Wenn Sie noch kein Konto haben, melden Sie sich hier für ein neues Konto an.
  2. Wählen Sie in der Google Cloud Console auf der Seite der Projektauswahl ein Cloud-Projekt aus oder erstellen Sie eines.
  3. Aktivieren Sie die Abrechnung für Ihr Google Cloud-Projekt. Für die gesamte Nutzung von Google Cloud ist eine Abrechnung erforderlich.
  4. Installieren Sie die gcloud-Alphakomponenten.
  5. Führen Sie den folgenden Befehl aus, um die neueste Version der gcloud-Komponenten zu installieren.

    gcloud components update
    
  6. Aktivieren Sie die TPU API mit dem folgenden gcloud-Befehl in Cloud Shell. Sie können sie auch über die Google Cloud Console aktivieren.

    gcloud services enable tpu.googleapis.com
    
  7. Berechtigungen mit dem TPU-Dienstkonto für die Compute Engine API aktivieren

    Dienstkonten ermöglichen dem Cloud TPU-Dienst, auf andere Google Cloud-Dienste zuzugreifen. Ein nutzerverwaltetes Dienstkonto ist eine empfohlene Google Cloud-Praxis. Folgen Sie diesen Anleitungen, um Rollen zu erstellen und zu gewähren. Folgende Rollen sind erforderlich:

    • TPU-Administrator
    • Storage-Administrator
    • Log-Autor
    • Monitoring-Messwert-Autor

    a. Richten Sie XPK-Berechtigungen mit Ihrem Nutzerkonto für GKE ein: XPK.

  8. Erstellen Sie Umgebungsvariablen für die Projekt-ID und die Zone.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Erstellen Sie eine Dienstidentität für die TPU-VM.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Sichere Kapazität

Wenden Sie sich an Ihren Cloud TPU-Supportmitarbeiter, um ein TPU-Kontingent anzufordern und Fragen zur Kapazität zu stellen.

Cloud TPU-Umgebung bereitstellen

v6e-TPUs können mit GKE, mit GKE und XPK (einem Befehlszeilen-Wrapper für GKE) oder als in der Warteschlange befindliche Ressourcen bereitgestellt und verwaltet werden.

Vorbereitung

  • Prüfen Sie, ob für Ihr Projekt ein ausreichendes TPUS_PER_TPU_FAMILY-Kontingent vorhanden ist. Dieses gibt die maximale Anzahl der Chips an, auf die Sie in Ihrem Google Cloud-Projekt zugreifen können.
  • v6e wurde mit der folgenden Konfiguration getestet:
    • Python 3.10 oder höher
    • Nightly-Softwareversionen:
      • JAX pro Nacht 0.4.32.dev20240912
      • nightly LibTPU 0.1.dev20240912+nightly
    • Stabile Softwareversionen:
      • JAX + JAX-Bibliothek der Version 0.4.35
  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPU-Kontingent für Folgendes hat:
    • TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Hyperdisk-Balance-Kontingent
  • Nutzerberechtigungen für Projekte

Umgebungsvariablen

Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Beschreibung der Befehls-Flags

Variable Beschreibung
NODE_ID Die vom Nutzer zugewiesene ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird.
PROJECT_ID Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter
ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“.

Beispiel: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES Die Anzahl der zu erstellenden Scheiben (nur für Mehrfachaufnahmen erforderlich)
QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage.
VALID_DURATION Die Dauer, für die die angeforderte Ressource gültig ist.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Optimierungen der Netzwerkleistung

Für die beste Leistung verwenden Sie ein Netzwerk mit 8.896 MTU (maximale Übertragungseinheit).

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Ethernet-Standard) oder 8.896 Byte (maximal möglich). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit 8.896 MTU erstellt.

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

Mehrere NICs verwenden (Option für Multi-Slice)

Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Verwenden Sie die folgenden Befehle, um eine benutzerdefinierte IP-Weiterleitung für das Netzwerk und das Subnetz zu erstellen.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

Nachdem ein Multi-Netzwerk-Speicherbereich erstellt wurde, können Sie prüfen, ob beide NICs verwendet werden, indem Sie --command ifconfig als Teil der XPK-Arbeitslast ausführen. Sehen Sie sich dann die gedruckte Ausgabe dieser XPK-Arbeitslast in den Cloud Console-Protokollen an und prüfen Sie, ob sowohl eth0 als auch eth1 eine mtu=8896 haben.

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8.896 haben. Sie können prüfen, ob Multi-NIC verwendet wird, indem Sie den Befehl „–command ifconfig“ als Teil der XPK-Arbeitslast ausführen. Sehen Sie sich dann die gedruckte Ausgabe dieser XPK-Arbeitslast in den Cloud Console-Protokollen an und prüfen Sie, ob sowohl eth0 als auch eth1 eine mtu=8896 haben.

Verbesserte TCP-Einstellungen

Für TPUs, die über die Benutzeroberfläche für Ressourcen in der Warteschlange erstellt wurden, können Sie den folgenden Befehl ausführen, um die Netzwerkleistung zu verbessern, indem Sie die Standard-TCP-Einstellungen für rto_min und quickack ändern.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

Bereitstellung mit Ressourcen in der Warteschlange (Cloud TPU API)

Die Kapazität kann mit dem Befehl create für Ressourcen in der Warteschlange bereitgestellt werden.

  1. Erstellen Sie eine Anfrage für eine in die Warteschlange gestellte TPU-Ressource.

    Das --reserved-Flag ist nur für reservierte Ressourcen erforderlich, nicht für On-Demand-Ressourcen.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    Wenn die Anfrage für die Ressourcenwarteschlange erfolgreich erstellt wurde, hat das Feld „response“ entweder den Status „WAITING_FOR_RESOURCES“ oder „FAILED“. Wenn sich die angeforderte Ressource in der Warteschlange im Status „WAITING_FOR_RESOURCES“ (Warten auf Ressourcen) befindet, wurde sie in die Warteschlange gestellt und wird bereitgestellt, sobald genügend TPU-Kapazität verfügbar ist. Wenn die Anfrage für die Ressourcenwarteschlange den Status „FAILED“ (Fehlgeschlagen) hat, wird der Grund für den Fehler in der Ausgabe angezeigt. Die anstehende Ressourcenanfrage läuft ab, wenn innerhalb der angegebenen Dauer kein v6e bereitgestellt wird. Der Status ändert sich dann in „FAILED“. Weitere Informationen finden Sie in der öffentlichen Dokumentation zu Ressourcen in der Warteschlange.

    Wenn Ihre in die Warteschlange gestellte Ressourcenanfrage den Status „AKTIV“ hat, können Sie über SSH eine Verbindung zu Ihren TPU-VMs herstellen. Verwenden Sie die Befehle list oder describe, um den Status der in der Warteschlange befindlichen Ressource abzufragen.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Wenn sich die erwartete Ressource im Status „AKTIV“ befindet, sieht die Ausgabe in etwa so aus:

      state:
       state: ACTIVE
    
  2. TPU-VMs verwalten Informationen zu den Optionen zum Verwalten Ihrer TPU-VMs finden Sie unter TPU-VMs verwalten.

  3. Über SSH eine Verbindung zu TPU-VMs herstellen

    Sie können Binärdateien auf jeder TPU-VM in Ihrem TPU-Speicherplatz installieren und Code ausführen. Im Abschnitt VM-Typen erfahren Sie, wie viele VMs Ihr Slice haben wird.

    Wenn Sie die Binärdateien installieren oder Code ausführen möchten, können Sie mit dem Befehl tpu-vm ssh eine SSH-Verbindung zu einer VM herstellen.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Wenn Sie über SSH eine Verbindung zu einer bestimmten VM herstellen möchten, verwenden Sie das Flag --worker, das einem Index mit dem Wert 0 folgt:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Wenn Sie Slice-Formen mit mehr als 8 Chips haben, befinden sich mehrere VMs in einem Slice. Verwenden Sie in diesem Fall die Parameter --worker=all und --command in Ihrem gcloud alpha compute tpus tpu-vm ssh-Befehl, um einen Befehl gleichzeitig auf allen VMs auszuführen. Beispiel:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. In die Warteschlange gestellte Ressource löschen

    Lösche am Ende der Sitzung eine in die Warteschlange gestellte Ressource oder entferne in die Warteschlange gestellte Ressourcenanfragen, die den Status „FAILED“ (Fehlgeschlagen) haben. Wenn Sie eine in die Warteschlange gestellte Ressource löschen möchten, löschen Sie in zwei Schritten den Ausschnitt und dann die Anfrage für die in die Warteschlange gestellte Ressource:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

GKE mit v6e verwenden

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter TPUs in GKE planen erfahren Sie, wie Sie GKE mit TPUs und v6e verwenden.

Framework einrichten

In diesem Abschnitt wird die allgemeine Einrichtung für das Training von ML-Modellen mit den Frameworks JAX, PyTorch oder TensorFlow beschrieben. Sie können TPUs mit Ressourcen in der Warteschlange oder mit GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder Kubernetes-Befehlen erfolgen.

JAX mit Ressourcen in der Warteschlange einrichten

Mit gcloud alpha compute tpus tpu-vm ssh können Sie JAX gleichzeitig auf allen TPU-VMs in Ihrem Slice oder Ihren Slices installieren. Fügen Sie für Multislice --node=all hinzu.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Sie können den folgenden Python-Code ausführen, um zu prüfen, wie viele TPU-Kerne in Ihrem Snippet verfügbar sind und ob alles richtig installiert ist. Die hier gezeigten Ausgaben wurden mit einem v6e-16-Snippet erstellt:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Die Ausgabe sieht in etwa so aus:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

Fehlerbehebung bei JAX-Einrichtungen

Als allgemeinen Tipp können wir Ihnen empfehlen, im Manifest Ihrer GKE-Arbeitslast ausführliches Logging zu aktivieren. Reichen Sie die Protokolle dann beim GKE-Support ein.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Fehlermeldungen

no endpoints available for service 'jobset-webhook-service'

Dieser Fehler bedeutet, dass der Jobsatz nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods für die Bereitstellung von „jobset-controller-manager“ ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.

TPU initialization failed: Failed to connect

Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein. GKE 1.31 wird nicht unterstützt.

Einrichtung für PyTorch

In diesem Abschnitt wird beschrieben, wie Sie PJRT in v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.

PyTorch mit GKE und XPK einrichten

Sie können den folgenden Docker-Container mit XPK verwenden, in dem die PyTorch-Abhängigkeiten bereits installiert sind:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Führen Sie den folgenden Befehl aus, um eine XPK-Arbeitslast zu erstellen:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Mit --base-docker-image wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.

PyTorch mit Ressourcen in der Warteschlange einrichten

Führen Sie die folgenden Schritte aus, um PyTorch mit anstehenden Ressourcen zu installieren und ein kleines Script auf v6e auszuführen.

Installieren Sie die Abhängigkeiten über SSH, um auf die VMs zuzugreifen.

Fügen Sie für Multislice --node=all hinzu:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Leistung von Modellen mit großen, häufigen Zuweisungen verbessern

Bei Modellen mit großen, häufigen Zuweisungen haben wir festgestellt, dass die Verwendung von tcmalloc die Leistung im Vergleich zur Standardimplementierung von malloc erheblich verbessert. Daher ist tcmalloc die Standard-malloc auf der TPU-VM. Je nach Arbeitslast kann tcmalloc jedoch zu einer Verlangsamung führen, z. B. bei DLRM mit sehr großen Zuweisungen für die Einbettungstabellen. In diesem Fall können Sie versuchen, die folgende Variable zurückzusetzen und stattdessen die Standardeinstellung malloc zu verwenden:

unset LD_PRELOAD

Mit einem Python-Script eine Berechnung auf einer v6e-VM ausführen:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Einrichtung für TensorFlow

Für die öffentliche Vorschau von v6e wird nur die Laufzeitversion „tf-nightly“ unterstützt.

Sie können tpu-runtime mit der mit v6e kompatiblen TensorFlow-Version zurücksetzen, indem Sie die folgenden Befehle ausführen:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Verwenden Sie SSH, um auf worker-0 zuzugreifen:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Installieren Sie TensorFlow auf worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exportieren Sie die Umgebungsvariable TPU_NAME:

export TPU_NAME=v6e-16

Mit dem folgenden Python-Script können Sie prüfen, wie viele TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles richtig installiert ist. Die angezeigten Ausgaben wurden mit einem v6e-16-Speicherplatz erstellt:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

Die Ausgabe sieht in etwa so aus:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e mit SkyPilot

Sie können TPU v6e mit SkyPilot verwenden. Führen Sie die folgenden Schritte aus, um SkyPilot v6e-bezogene Standort-/Preisinformationen hinzuzufügen.

  1. Fügen Sie am Ende von ~/.sky/catalogs/v5/gcp/vms.csv Folgendes hinzu :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Geben Sie die folgenden Ressourcen in einer YAML-Datei an:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Cluster mit TPU v6e starten:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Stellen Sie eine SSH-Verbindung zur TPU v6e her: ssh tpu_v6

Anleitungen für Inferenz

In den folgenden Abschnitten finden Sie Anleitungen zum Bereitstellen von MaxText- und PyTorch-Modellen mit JetStream sowie zum Bereitstellen von MaxDiffusion-Modellen auf TPU v6e.

MaxText auf JetStream

In dieser Anleitung erfahren Sie, wie Sie mit JetStream MaxText-Modelle (JAX) auf TPU v6e bereitstellen. JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs). In dieser Anleitung führen Sie den Inferenz-Benchmark für das Llama2-7B-Modell aus.

Hinweise

  1. TPU v6e mit 4 Chips erstellen:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Stellen Sie eine SSH-Verbindung zur TPU her:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Anleitung ausführen

Folgen Sie der Anleitung im GitHub-Repository, um JetStream und MaxText einzurichten, die Modell-Checkpunkte zu konvertieren und den Inferenz-Benchmark auszuführen.

Bereinigen

Löschen Sie die TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

vLLM auf PyTorch TPU

Unten finden Sie eine einfache Anleitung für den Einstieg in vLLM auf einer TPU-VM. In den nächsten Tagen veröffentlichen wir einen GKE-Leitfaden mit Best Practices für die Bereitstellung von vLLM auf Trillium in der Produktion. Mehr dazu demnächst!

Hinweise

  1. TPU v6e mit 4 Chips erstellen:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    Beschreibung der Befehls-Flags

    Variable Beschreibung
    NODE_ID Die vom Nutzer zugewiesene ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird.
    PROJECT_ID Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter
    ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“.

    Beispiel: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. Stellen Sie eine SSH-Verbindung zur TPU her:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

vLLM auf TPU einrichten

  1. Klonen Sie das vLLM-Repository und wechseln Sie zum vLLM-Verzeichnis:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. Bereinigen Sie die vorhandenen Pakete „torch“ und „torch-xla“:

    pip uninstall torch torch-xla -y
    
  3. Installieren Sie PyTorch und PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. Installieren Sie JAX und Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. Installieren Sie andere Build-Abhängigkeiten:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

Zugriff auf das Modell erhalten

Sie müssen die Einwilligungsvereinbarung unterzeichnen, um die Llama3-Modellfamilie im HuggingFace-Repository verwenden zu können.

Generieren Sie ein neues Hugging Face-Token, falls Sie noch keines haben:

  1. Klicken Sie auf Profil > Einstellungen > Zugriffstokens.
  2. Wählen Sie Neues Token aus.
  3. Geben Sie einen Namen Ihrer Wahl und eine Rolle von mindestens Read an.
  4. Wählen Sie Token generieren aus.
  5. Kopieren Sie das generierte Token in die Zwischenablage, legen Sie es als Umgebungsvariable fest und authentifizieren Sie sich mit der huggingface-cli:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

Benchmarking-Daten herunterladen

  1. Erstellen Sie das Verzeichnis „/data“ und laden Sie das ShareGPT-Dataset von Hugging Face herunter.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

vLLM-Server starten

Mit dem folgenden Befehl werden die Modellgewichte aus dem Hugging Face Model Hub in das Verzeichnis „/tmp“ der TPU-VM heruntergeladen, eine Reihe von Eingabeformen vorab kompiliert und die Modellkompilierung in ~/.cache/vllm/xla_cache geschrieben.

Weitere Informationen finden Sie in der vLLM-Dokumentation.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

vLLM-Benchmarks ausführen

Führen Sie das vLLM-Benchmarking-Script aus:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

Bereinigen

Löschen Sie die TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch auf JetStream

In dieser Anleitung erfahren Sie, wie Sie mit JetStream PyTorch-Modelle auf TPU v6e bereitstellen. JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs). In dieser Anleitung führen Sie den Inferenz-Benchmark für das Llama2-7B-Modell aus.

Hinweise

  1. TPU v6e mit 4 Chips erstellen:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Stellen Sie eine SSH-Verbindung zur TPU her:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Anleitung ausführen

Folgen Sie der Anleitung im GitHub-Repository, um JetStream-PyTorch einzurichten, die Modell-Checkpunkte zu konvertieren und den Inferenz-Benchmark auszuführen.

Bereinigen

Löschen Sie die TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

MaxDiffusion-Inferenz

In dieser Anleitung wird beschrieben, wie Sie MaxDiffusion-Modelle auf TPU v6e bereitstellen. In dieser Anleitung generieren Sie Bilder mit dem Stable Diffusion XL-Modell.

Hinweise

  1. TPU v6e mit 4 Chips erstellen:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Stellen Sie eine SSH-Verbindung zur TPU her:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Conda-Umgebung erstellen

  1. Erstellen Sie ein Verzeichnis für Miniconda:

    mkdir -p ~/miniconda3
  2. Laden Sie das Miniconda-Installationsskript herunter:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Installieren Sie Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Entfernen Sie das Miniconda-Installationsskript:

    rm -rf ~/miniconda3/miniconda.sh
  5. Fügen Sie Miniconda der Variablen PATH hinzu:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Laden Sie ~/.bashrc neu, um die Änderungen auf die Variable PATH anzuwenden:

    source ~/.bashrc
  7. Erstellen Sie eine neue Conda-Umgebung:

    conda create -n tpu python=3.10
  8. Aktivieren Sie die Conda-Umgebung:

    source activate tpu

MaxDiffusion einrichten

  1. Klonen Sie das MaxDiffusion-Repository und wechseln Sie zum MaxDiffusion-Verzeichnis:

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Wechseln Sie zum mlperf-4.1-Zweig:

    git checkout mlperf4.1
  3. Installieren Sie MaxDiffusion:

    pip install -e .
  4. Installieren Sie die Abhängigkeiten:

    pip install -r requirements.txt
  5. JAX installieren:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Bilder erstellen

  1. Legen Sie Umgebungsvariablen fest, um die TPU-Laufzeit zu konfigurieren:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Bilder mit dem Prompt und den Konfigurationen generieren, die in src/maxdiffusion/configs/base_xl.yml definiert sind:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

Bereinigen

Löschen Sie die TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

Trainingsanleitungen

In den folgenden Abschnitten finden Sie Anleitungen zum Trainieren von MaxText.

MaxDiffusion- und PyTorch-Modelle auf TPU v6e

MaxText und MaxDiffusion

In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.

Im Allgemeinen sind dies die allgemeinen Schritte:

  1. Erstellen Sie das Basis-Image der Arbeitslast.
  2. Führen Sie die Arbeitslast mit XPK aus.
    1. Erstellen Sie den Trainingsbefehl für die Arbeitslast.
    2. Stellen Sie die Arbeitslast bereit.
  3. Arbeitslast verfolgen und Messwerte ansehen
  4. Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
  5. Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie zum Verzeichnis des Repositorys:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Konfigurieren Sie Docker so, dass die Google Cloud CLI verwendet wird:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Docker-Image mit JAX Stable Stack erstellen

Sie können die Docker-Images „MaxText“ und „MaxDiffusion“ mit dem Basis-Image „JAX Stable Stack“ erstellen.

Der JAX Stable Stack bietet eine einheitliche Umgebung für MaxText und MaxDiffusion, da JAX mit Kernpaketen wie orbax, flax und optax sowie einer gut qualifizierten libtpu.so kombiniert wird, die TPU-Programm-Dienstprogramme und andere wichtige Tools antreibt. Diese Bibliotheken werden auf Kompatibilität getestet, um eine stabile Grundlage für das Erstellen und Ausführen von MaxText und MaxDiffusion zu schaffen und potenzielle Konflikte aufgrund von inkompatiblen Paketversionen zu vermeiden.

Der stabile JAX-Stack enthält eine vollständig veröffentlichte und qualifizierte libtpu.so, die Kernbibliothek, die die Kompilierung, Ausführung und ICI-Netzwerkkonfiguration von TPU-Programmen steuert. Der libtpu-Release ersetzt den bisher von JAX verwendeten Nightly-Build und sorgt mit Qualifikationstests auf PJRT-Ebene in HLO/StableHLO-IRs für eine konsistente Funktionsweise von XLA-Berechnungen auf TPUs.

Wenn Sie das Docker-Image „MaxText“ und „MaxDiffusion“ mit dem JAX Stable Stack erstellen möchten, legen Sie beim Ausführen des docker_build_dependency_image.sh-Scripts die Variable MODE auf stable_stack und die Variable BASEIMAGE auf das gewünschte Basis-Image fest.

Im folgenden Beispiel wird us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 als Basisbild angegeben:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Eine Liste der verfügbaren JAX Stable Stack-Basis-Images finden Sie unter JAX Stable Stack-Images in der Artifact Registry.

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte verwenden, die mit MaxText oder MaxDiffusion festgelegt wurden:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. Erstellen Sie Ihr Modellskript, das im nächsten Schritt als Trainingsbefehl kopiert wird. Führen Sie das Modellskript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarker, hoch skalierbarer Open-Source-LLM, der in reiner Python- und JAX-Programmierung geschrieben wurde und für das Training und die Inferenz auf Google Cloud TPUs und GPUs ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine Familie von Large Language Models (LLMs) mit offenen Gewichten, die von Google DeepMind auf der Grundlage der Gemini-Forschung und -Technologie entwickelt wurden.

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine sparse MoE-Architektur (Mixture of Experts) verwendet.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine Familie offener Large Language Models (LLMs), die von Meta entwickelt wurden.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reiner Python- und JAX-Programmierung geschrieben wurden und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert.

    Sie müssen einen bestimmten Branch installieren, um MaxDiffusion auszuführen:

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Trainingsskript:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. Führen Sie das Modell mit dem Skript aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basis-Image zu verwenden, oder das Flag --docker-image und das gewünschte Image.

    Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX bei MaxText debuggen.

    Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Fügen Sie dazu das Flag --use-vertex-tensorboard hinzu. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    Ersetzen Sie die folgenden Variablen:

    • CLUSTER_NAME: Der Name Ihres XPK-Clusters.
    • ACCELERATOR_TYPE: Die Version und Größe Ihrer TPU. Beispiel: v6e-256.
    • NUM_SLICES: Die Anzahl der TPU-Slices.
    • YOUR_MODEL_SCRIPT: Das Modellskript, das als Trainingsbefehl ausgeführt werden soll.

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können, ähnlich dem folgenden:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Öffnen Sie den Link und klicken Sie auf den Tab Protokolle, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird:

  • XPK-Arbeitslastliste
  • XPK-Inspektor
  • Aktivieren Sie beim Erstellen der XPK-Arbeitslast mit dem Flag --enable-debug-logs ausführliche Protokolle in Ihren Arbeitslastprotokollen.

JAX auf MaxText mit Vertex AI überwachen

Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI aufrufen

  1. Erhöhen Sie die Anzahl der Resource Management (CRUD)-Anfragen für die von Ihnen verwendete Zone von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie unter Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.

  4. Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihre XPK-Arbeitslast mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen.

Die Protokolle enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich wie hier:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie finden den Link zu Vertex AI TensorBoard auch in der Google Cloud Console. Rufen Sie in der Google Cloud Console Vertex AI-Tests auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslasten löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um einen Cluster zu löschen:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama und PyTorch

In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf einer TPU v6e mit dem Dataset WikiText trainieren. Außerdem können Nutzer hier auf PyTorch-TPU-Modellbeschreibungen als Docker-Images zugreifen.

Installation

Installieren Sie den pytorch-tpu/transformers-Fork von Hugging Face Transformers und die Abhängigkeiten in einer virtuellen Umgebung:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

Modellkonfigurationen einrichten

Der Trainingsbefehl im nächsten Abschnitt Modellscript erstellen verwendet zwei JSON-Konfigurationsdateien, um Modellparameter und die FSDP-Konfiguration (Fully Sharded Data Parallel) zu definieren. Das FSDP-Sharding wird verwendet, damit die Modellgewichte während des Trainings zu einer größeren Batchgröße passen. Beim Training mit kleineren Modellen reicht es möglicherweise aus, nur Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Aufteilen von Tensoren auf Geräte in PyTorch/XLA finden Sie im PyTorch/XLA SPMD-Nutzerhandbuch.

  1. Erstellen Sie die Konfigurationsdatei für die Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama3-8B. Für andere Modelle finden Sie die Konfiguration auf Hugging Face. Siehe beispielsweise die Llama2-7B-Konfiguration.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. Erstellen Sie die FSDP-Konfigurationsdatei:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    Weitere Informationen zu FSDP finden Sie unter FSDPv2.

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre TPU-VMs hoch:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    Sie können die Konfigurationsdateien auch in Ihrem aktuellen Arbeitsverzeichnis erstellen und das Flag --base-docker-image in XPK verwenden.

Modellscript erstellen

Erstellen Sie Ihr Modellscript und geben Sie die Konfigurationsdatei für die Modellparameter mit dem Flag --config_name und die FSDP-Konfigurationsdatei mit dem Flag --fsdp_config an. Sie führen dieses Script im nächsten Abschnitt Modell ausführen auf Ihrer TPU aus. Führen Sie das Modellskript noch nicht aus.

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

Modell ausführen

Führen Sie das Modell mit dem im vorherigen Schritt erstellten Script aus: Modellscript erstellen.

Wenn Sie eine TPU-VM mit einem einzelnen Host verwenden (z. B. v6e-4), können Sie den Trainingsbefehl direkt auf der TPU-VM ausführen. Wenn Sie eine TPU-VM mit mehreren Hosts verwenden, können Sie das Script mit dem folgenden Befehl gleichzeitig auf allen Hosts ausführen:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debuggen im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell an dem Speicherort gespeichert, der in der Variablen PROFILE_LOGDIR angegeben ist. Sie können die dort gespeicherte xplane.pb-Datei extrahieren und mit tensorboard die Profile in Ihrem Browser anzeigen lassen. Folgen Sie dazu der TensorBoard-Anleitung. Wenn PyTorch/XLA nicht wie erwartet funktioniert, lesen Sie den Leitfaden zur Fehlerbehebung. Dort finden Sie Vorschläge zum Debuggen, Profilieren und Optimieren Ihrer Modelle.

DLRM DCN v2-Anleitung

In dieser Anleitung wird beschrieben, wie Sie das DLRM DCN v2-Modell auf einer TPU v6e trainieren.

Wenn Sie mehrere Hosts verwenden, setzen Sie tpu-runtime mit der entsprechenden TensorFlow-Version zurück. Führen Sie dazu den folgenden Befehl aus. Wenn Sie die Anwendung auf einem einzelnen Host ausführen, müssen Sie die folgenden beiden Befehle nicht ausführen.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH-Verbindung zu worker-0 herstellen

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

TPU-Namen festlegen

export TPU_NAME=${TPU_NAME}

DLRM v2 ausführen

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Führen Sie script.sh aus.

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Die folgenden Flags sind erforderlich, um Empfehlungsarbeitslasten (DLRM DCN) auszuführen:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Benchmarking-Ergebnisse

Der folgende Abschnitt enthält Benchmarking-Ergebnisse für DLRM DCN v2 und MaxDiffusion auf v6e.

DLRM DCN v2

Das DLRM DCN v2-Trainingsskript wurde in verschiedenen Skalen ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-64 v6e-128 v6e-256
Trainingsschritte 7000 7000 7000
Globale Batchgröße 131.072 262144 524.288
Durchsatz (Beispiele/Sek.) 2975334 5111808 10066329

MaxDiffusion

Wir haben das Trainingsskript für MaxDiffusion auf einem v6e-4, einem v6e-16 und einem 2xv6e-16 ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0.069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/Sek.) 115,9 438,4 492,3

Sammlungen

In Version 6e wird die neue Funktion „Sammlungen“ für Nutzer eingeführt, die Bereitstellungslasten ausführen. Die Sammlungsfunktion gilt nur für Version 6e.

Mit Sammlungen können Sie Google Cloud mitteilen, welche Ihrer TPU-Knoten zu einer Bereitstellungsarbeitslast gehören. So kann die zugrunde liegende Google Cloud-Infrastruktur Unterbrechungen, die im normalen Betrieb auf Trainingslasten angewendet werden können, begrenzen und optimieren.

Sammlungen aus der Cloud TPU API verwenden

Eine Sammlung mit einem einzelnen Host in der Cloud TPU API ist eine Ressourcenwarteschlange, für die ein spezielles Flag (--workload-type = availability-optimized) festgelegt ist, um der zugrunde liegenden Infrastruktur anzugeben, dass sie für die Bereitstellung von Arbeitslasten verwendet werden soll.

Mit dem folgenden Befehl wird eine Sammlung mit einem einzelnen Host mithilfe der Cloud TPU API bereitgestellt:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

Überwachen und profilieren

Cloud TPU v6e unterstützt Monitoring und Profiling mit denselben Methoden wie frühere Cloud TPU-Generationen. Weitere Informationen zum Monitoring finden Sie unter TPU-VMs überwachen.