Trillium (v6e) – Einführung
In dieser Dokumentation, in der TPU API und in den Protokollen wird v6e für Trillium verwendet. v6e steht für die sechste Generation von TPU von Google.
Mit 256 Chips pro Pod hat v6e viele Ähnlichkeiten mit v5e. Dieses System ist für das Training, die Feinabstimmung und das Bereitstellen von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks) optimiert.
v6e-Systemarchitektur
Informationen zur Cloud TPU-Konfiguration finden Sie in der v6e-Dokumentation.
In diesem Dokument liegt der Schwerpunkt auf der Einrichtung des Modelltrainings mit den Frameworks JAX, PyTorch oder TensorFlow. Mit jedem Framework können Sie TPUs mithilfe von Ressourcen in der Warteschlange oder der Google Kubernetes Engine (GKE) bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.
Google Cloud-Projekt vorbereiten
- Melden Sie sich in Ihrem Google-Konto an. Wenn Sie noch kein Konto haben, melden Sie sich hier für ein neues Konto an.
- Wählen Sie in der Google Cloud Console auf der Seite der Projektauswahl ein Cloud-Projekt aus oder erstellen Sie eines.
- Aktivieren Sie die Abrechnung für Ihr Google Cloud-Projekt. Für die gesamte Nutzung von Google Cloud ist eine Abrechnung erforderlich.
- Installieren Sie die gcloud-Alphakomponenten.
Führen Sie den folgenden Befehl aus, um die neueste Version der
gcloud
-Komponenten zu installieren.gcloud components update
Aktivieren Sie die TPU API mit dem folgenden
gcloud
-Befehl in Cloud Shell. Sie können sie auch über die Google Cloud Console aktivieren.gcloud services enable tpu.googleapis.com
Berechtigungen mit dem TPU-Dienstkonto für die Compute Engine API aktivieren
Dienstkonten ermöglichen dem Cloud TPU-Dienst, auf andere Google Cloud-Dienste zuzugreifen. Ein nutzerverwaltetes Dienstkonto ist eine empfohlene Google Cloud-Praxis. Folgen Sie diesen Anleitungen, um Rollen zu erstellen und zu gewähren. Folgende Rollen sind erforderlich:
- TPU-Administrator
- Storage-Administrator
- Log-Autor
- Monitoring-Messwert-Autor
a. Richten Sie XPK-Berechtigungen mit Ihrem Nutzerkonto für GKE ein: XPK.
Erstellen Sie Umgebungsvariablen für die Projekt-ID und die Zone.
gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Erstellen Sie eine Dienstidentität für die TPU-VM.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Sichere Kapazität
Wenden Sie sich an Ihren Cloud TPU-Supportmitarbeiter, um ein TPU-Kontingent anzufordern und Fragen zur Kapazität zu stellen.
Cloud TPU-Umgebung bereitstellen
v6e-TPUs können mit GKE, mit GKE und XPK (einem Befehlszeilen-Wrapper für GKE) oder als in der Warteschlange befindliche Ressourcen bereitgestellt und verwaltet werden.
Vorbereitung
- Prüfen Sie, ob für Ihr Projekt ein ausreichendes
TPUS_PER_TPU_FAMILY
-Kontingent vorhanden ist. Dieses gibt die maximale Anzahl der Chips an, auf die Sie in Ihrem Google Cloud-Projekt zugreifen können. - v6e wurde mit der folgenden Konfiguration getestet:
- Python
3.10
oder höher - Nightly-Softwareversionen:
- JAX pro Nacht
0.4.32.dev20240912
- nightly LibTPU
0.1.dev20240912+nightly
- JAX pro Nacht
- Stabile Softwareversionen:
- JAX + JAX-Bibliothek der Version 0.4.35
- Python
- Prüfen Sie, ob Ihr Projekt ein ausreichendes TPU-Kontingent für Folgendes hat:
- TPU-VM-Kontingent
- Kontingent für IP-Adressen
- Hyperdisk-Balance-Kontingent
- Nutzerberechtigungen für Projekte
- Wenn Sie GKE mit XPK verwenden, finden Sie unter Cloud Console-Berechtigungen für das Nutzer- oder Dienstkonto Informationen zu den Berechtigungen, die zum Ausführen von XPK erforderlich sind.
Umgebungsvariablen
Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-central2-b export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the # default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Beschreibung der Befehls-Flags
Variable | Beschreibung |
NODE_ID | Die vom Nutzer zugewiesene ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird. |
PROJECT_ID | Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter |
ZONE | Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. |
ACCELERATOR_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“.
Beispiel: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | Die Anzahl der zu erstellenden Scheiben (nur für Mehrfachaufnahmen erforderlich) |
QUEUED_RESOURCE_ID | Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage. |
VALID_DURATION | Die Dauer, für die die angeforderte Ressource gültig ist. |
NETWORK_NAME | Der Name eines sekundären Netzwerks, das verwendet werden soll. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
Optimierungen der Netzwerkleistung
Für die beste Leistung verwenden Sie ein Netzwerk mit 8.896 MTU (maximale Übertragungseinheit).
Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Ethernet-Standard) oder 8.896 Byte (maximal möglich). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.
Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.
Im folgenden Beispiel wird ein Netzwerk mit 8.896 MTU erstellt.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME} export NETWORK_FW_NAME=${RESOURCE_NAME} export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \
Mehrere NICs verwenden (Option für Multi-Slice)
Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Verwenden Sie die folgenden Befehle, um eine benutzerdefinierte IP-Weiterleitung für das Netzwerk und das Subnetz zu erstellen.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT}" \
--enable-logging
Nachdem ein Multi-Netzwerk-Speicherbereich erstellt wurde, können Sie prüfen, ob beide NICs verwendet werden, indem Sie --command ifconfig
als Teil der XPK-Arbeitslast ausführen. Sehen Sie sich dann die gedruckte Ausgabe dieser XPK-Arbeitslast in den Cloud Console-Protokollen an und prüfen Sie, ob sowohl eth0 als auch eth1 eine mtu=8896 haben.
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8.896 haben. Sie können prüfen, ob Multi-NIC verwendet wird, indem Sie den Befehl „–command ifconfig“ als Teil der XPK-Arbeitslast ausführen. Sehen Sie sich dann die gedruckte Ausgabe dieser XPK-Arbeitslast in den Cloud Console-Protokollen an und prüfen Sie, ob sowohl eth0 als auch eth1 eine mtu=8896 haben.
Verbesserte TCP-Einstellungen
Für TPUs, die über die Benutzeroberfläche für Ressourcen in der Warteschlange erstellt wurden, können Sie den folgenden Befehl ausführen, um die Netzwerkleistung zu verbessern, indem Sie die Standard-TCP-Einstellungen für rto_min
und quickack
ändern.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" --zone "${ZONE}" \ --command='ip route show | while IFS= read -r route; do if ! echo $route | \ grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \ --worker=all
Bereitstellung mit Ressourcen in der Warteschlange (Cloud TPU API)
Die Kapazität kann mit dem Befehl create
für Ressourcen in der Warteschlange bereitgestellt werden.
Erstellen Sie eine Anfrage für eine in die Warteschlange gestellte TPU-Ressource.
Das
--reserved
-Flag ist nur für reservierte Ressourcen erforderlich, nicht für On-Demand-Ressourcen.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved]
Wenn die Anfrage für die Ressourcenwarteschlange erfolgreich erstellt wurde, hat das Feld „response“ entweder den Status „WAITING_FOR_RESOURCES“ oder „FAILED“. Wenn sich die angeforderte Ressource in der Warteschlange im Status „WAITING_FOR_RESOURCES“ (Warten auf Ressourcen) befindet, wurde sie in die Warteschlange gestellt und wird bereitgestellt, sobald genügend TPU-Kapazität verfügbar ist. Wenn die Anfrage für die Ressourcenwarteschlange den Status „FAILED“ (Fehlgeschlagen) hat, wird der Grund für den Fehler in der Ausgabe angezeigt. Die anstehende Ressourcenanfrage läuft ab, wenn innerhalb der angegebenen Dauer kein v6e bereitgestellt wird. Der Status ändert sich dann in „FAILED“. Weitere Informationen finden Sie in der öffentlichen Dokumentation zu Ressourcen in der Warteschlange.
Wenn Ihre in die Warteschlange gestellte Ressourcenanfrage den Status „AKTIV“ hat, können Sie über SSH eine Verbindung zu Ihren TPU-VMs herstellen. Verwenden Sie die Befehle
list
oderdescribe
, um den Status der in der Warteschlange befindlichen Ressource abzufragen.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Wenn sich die erwartete Ressource im Status „AKTIV“ befindet, sieht die Ausgabe in etwa so aus:
state: state: ACTIVE
TPU-VMs verwalten Informationen zu den Optionen zum Verwalten Ihrer TPU-VMs finden Sie unter TPU-VMs verwalten.
Über SSH eine Verbindung zu TPU-VMs herstellen
Sie können Binärdateien auf jeder TPU-VM in Ihrem TPU-Speicherplatz installieren und Code ausführen. Im Abschnitt VM-Typen erfahren Sie, wie viele VMs Ihr Slice haben wird.
Wenn Sie die Binärdateien installieren oder Code ausführen möchten, können Sie mit dem Befehl
tpu-vm ssh
eine SSH-Verbindung zu einer VM herstellen.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Wenn Sie über SSH eine Verbindung zu einer bestimmten VM herstellen möchten, verwenden Sie das Flag
--worker
, das einem Index mit dem Wert 0 folgt:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Wenn Sie Slice-Formen mit mehr als 8 Chips haben, befinden sich mehrere VMs in einem Slice. Verwenden Sie in diesem Fall die Parameter
--worker=all
und--command
in Ihremgcloud alpha compute tpus tpu-vm ssh
-Befehl, um einen Befehl gleichzeitig auf allen VMs auszuführen. Beispiel:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
In die Warteschlange gestellte Ressource löschen
Lösche am Ende der Sitzung eine in die Warteschlange gestellte Ressource oder entferne in die Warteschlange gestellte Ressourcenanfragen, die den Status „FAILED“ (Fehlgeschlagen) haben. Wenn Sie eine in die Warteschlange gestellte Ressource löschen möchten, löschen Sie in zwei Schritten den Ausschnitt und dann die Anfrage für die in die Warteschlange gestellte Ressource:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
GKE mit v6e verwenden
Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter TPUs in GKE planen erfahren Sie, wie Sie GKE mit TPUs und v6e verwenden.
Framework einrichten
In diesem Abschnitt wird die allgemeine Einrichtung für das Training von ML-Modellen mit den Frameworks JAX, PyTorch oder TensorFlow beschrieben. Sie können TPUs mit Ressourcen in der Warteschlange oder mit GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder Kubernetes-Befehlen erfolgen.
JAX mit Ressourcen in der Warteschlange einrichten
Mit gcloud alpha compute tpus tpu-vm ssh
können Sie JAX gleichzeitig auf allen TPU-VMs in Ihrem Slice oder Ihren Slices installieren. Fügen Sie für Multislice --node=all
hinzu.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Sie können den folgenden Python-Code ausführen, um zu prüfen, wie viele TPU-Kerne in Ihrem Snippet verfügbar sind und ob alles richtig installiert ist. Die hier gezeigten Ausgaben wurden mit einem v6e-16-Snippet erstellt:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Die Ausgabe sieht in etwa so aus:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
&& pip install -r requirements.txt && pip install . '
Fehlerbehebung bei JAX-Einrichtungen
Als allgemeinen Tipp können wir Ihnen empfehlen, im Manifest Ihrer GKE-Arbeitslast ausführliches Logging zu aktivieren. Reichen Sie die Protokolle dann beim GKE-Support ein.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Fehlermeldungen
no endpoints available for service 'jobset-webhook-service'
Dieser Fehler bedeutet, dass der Jobsatz nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods für die Bereitstellung von „jobset-controller-manager“ ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.
TPU initialization failed: Failed to connect
Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein. GKE 1.31 wird nicht unterstützt.
Einrichtung für PyTorch
In diesem Abschnitt wird beschrieben, wie Sie PJRT in v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.
PyTorch mit GKE und XPK einrichten
Sie können den folgenden Docker-Container mit XPK verwenden, in dem die PyTorch-Abhängigkeiten bereits installiert sind:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Führen Sie den folgenden Befehl aus, um eine XPK-Arbeitslast zu erstellen:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
Mit --base-docker-image
wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.
PyTorch mit Ressourcen in der Warteschlange einrichten
Führen Sie die folgenden Schritte aus, um PyTorch mit anstehenden Ressourcen zu installieren und ein kleines Script auf v6e auszuführen.
Installieren Sie die Abhängigkeiten über SSH, um auf die VMs zuzugreifen.
Fügen Sie für Multislice --node=all
hinzu:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Leistung von Modellen mit großen, häufigen Zuweisungen verbessern
Bei Modellen mit großen, häufigen Zuweisungen haben wir festgestellt, dass die Verwendung von tcmalloc
die Leistung im Vergleich zur Standardimplementierung von malloc
erheblich verbessert. Daher ist tcmalloc
die Standard-malloc
auf der TPU-VM. Je nach Arbeitslast kann tcmalloc
jedoch zu einer Verlangsamung führen, z. B. bei DLRM mit sehr großen Zuweisungen für die Einbettungstabellen. In diesem Fall können Sie versuchen, die folgende Variable zurückzusetzen und stattdessen die Standardeinstellung malloc
zu verwenden:
unset LD_PRELOAD
Mit einem Python-Script eine Berechnung auf einer v6e-VM ausführen:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Dadurch wird eine Ausgabe generiert, die etwa so aussieht:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Einrichtung für TensorFlow
Für die öffentliche Vorschau von v6e wird nur die Laufzeitversion „tf-nightly“ unterstützt.
Sie können tpu-runtime
mit der mit v6e kompatiblen TensorFlow-Version zurücksetzen, indem Sie die folgenden Befehle ausführen:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Verwenden Sie SSH, um auf worker-0 zuzugreifen:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Installieren Sie TensorFlow auf worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exportieren Sie die Umgebungsvariable TPU_NAME
:
export TPU_NAME=v6e-16
Mit dem folgenden Python-Script können Sie prüfen, wie viele TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles richtig installiert ist. Die angezeigten Ausgaben wurden mit einem v6e-16-Speicherplatz erstellt:
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
Die Ausgabe sieht in etwa so aus:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e mit SkyPilot
Sie können TPU v6e mit SkyPilot verwenden. Führen Sie die folgenden Schritte aus, um SkyPilot v6e-bezogene Standort-/Preisinformationen hinzuzufügen.
Fügen Sie am Ende von
~/.sky/catalogs/v5/gcp/vms.csv
Folgendes hinzu :,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Geben Sie die folgenden Ressourcen in einer YAML-Datei an:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Cluster mit TPU v6e starten:
sky launch tpu_v6.yaml -c tpu_v6
Stellen Sie eine SSH-Verbindung zur TPU v6e her:
ssh tpu_v6
Anleitungen für Inferenz
In den folgenden Abschnitten finden Sie Anleitungen zum Bereitstellen von MaxText- und PyTorch-Modellen mit JetStream sowie zum Bereitstellen von MaxDiffusion-Modellen auf TPU v6e.
MaxText auf JetStream
In dieser Anleitung erfahren Sie, wie Sie mit JetStream MaxText-Modelle (JAX) auf TPU v6e bereitstellen. JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs). In dieser Anleitung führen Sie den Inferenz-Benchmark für das Llama2-7B-Modell aus.
Hinweise
TPU v6e mit 4 Chips erstellen:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Stellen Sie eine SSH-Verbindung zur TPU her:
gcloud compute tpus tpu-vm ssh TPU_NAME
Anleitung ausführen
Folgen Sie der Anleitung im GitHub-Repository, um JetStream und MaxText einzurichten, die Modell-Checkpunkte zu konvertieren und den Inferenz-Benchmark auszuführen.
Bereinigen
Löschen Sie die TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
vLLM auf PyTorch TPU
Unten finden Sie eine einfache Anleitung für den Einstieg in vLLM auf einer TPU-VM. In den nächsten Tagen veröffentlichen wir einen GKE-Leitfaden mit Best Practices für die Bereitstellung von vLLM auf Trillium in der Produktion. Mehr dazu demnächst!
Hinweise
TPU v6e mit 4 Chips erstellen:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Beschreibung der Befehls-Flags
Variable Beschreibung NODE_ID Die vom Nutzer zugewiesene ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird. PROJECT_ID Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen. RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“. Beispiel: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com
Stellen Sie eine SSH-Verbindung zur TPU her:
gcloud compute tpus tpu-vm ssh TPU_NAME
Create a Conda environment
(Recommended) Create a new conda environment for vLLM:
conda create -n vllm python=3.10 -y conda activate vllm
vLLM auf TPU einrichten
Klonen Sie das vLLM-Repository und wechseln Sie zum vLLM-Verzeichnis:
git clone https://github.com/vllm-project/vllm.git && cd vllm
Bereinigen Sie die vorhandenen Pakete „torch“ und „torch-xla“:
pip uninstall torch torch-xla -y
Installieren Sie PyTorch und PyTorch XLA:
pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Installieren Sie JAX und Pallas:
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Installieren Sie andere Build-Abhängigkeiten:
pip install -r requirements-tpu.txt VLLM_TARGET_DEVICE="tpu" python setup.py develop sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Zugriff auf das Modell erhalten
Sie müssen die Einwilligungsvereinbarung unterzeichnen, um die Llama3-Modellfamilie im HuggingFace-Repository verwenden zu können.
Generieren Sie ein neues Hugging Face-Token, falls Sie noch keines haben:
- Klicken Sie auf Profil > Einstellungen > Zugriffstokens.
- Wählen Sie Neues Token aus.
- Geben Sie einen Namen Ihrer Wahl und eine Rolle von mindestens
Read
an. - Wählen Sie Token generieren aus.
Kopieren Sie das generierte Token in die Zwischenablage, legen Sie es als Umgebungsvariable fest und authentifizieren Sie sich mit der huggingface-cli:
export TOKEN='' git config --global credential.helper store huggingface-cli login --token $TOKEN
Benchmarking-Daten herunterladen
Erstellen Sie das Verzeichnis „/data“ und laden Sie das ShareGPT-Dataset von Hugging Face herunter.
mkdir ~/data && cd ~/data wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
vLLM-Server starten
Mit dem folgenden Befehl werden die Modellgewichte aus dem Hugging Face Model Hub in das Verzeichnis „/tmp“ der TPU-VM heruntergeladen, eine Reihe von Eingabeformen vorab kompiliert und die Modellkompilierung in ~/.cache/vllm/xla_cache
geschrieben.
Weitere Informationen finden Sie in der vLLM-Dokumentation.
cd ~/vllm
vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &
vLLM-Benchmarks ausführen
Führen Sie das vLLM-Benchmarking-Script aus:
python benchmarks/benchmark_serving.py \
--backend vllm \
--model "meta-llama/Meta-Llama-3.1-8B" \
--dataset-name sharegpt \
--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 1000
Bereinigen
Löschen Sie die TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
PyTorch auf JetStream
In dieser Anleitung erfahren Sie, wie Sie mit JetStream PyTorch-Modelle auf TPU v6e bereitstellen. JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs). In dieser Anleitung führen Sie den Inferenz-Benchmark für das Llama2-7B-Modell aus.
Hinweise
TPU v6e mit 4 Chips erstellen:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Stellen Sie eine SSH-Verbindung zur TPU her:
gcloud compute tpus tpu-vm ssh TPU_NAME
Anleitung ausführen
Folgen Sie der Anleitung im GitHub-Repository, um JetStream-PyTorch einzurichten, die Modell-Checkpunkte zu konvertieren und den Inferenz-Benchmark auszuführen.
Bereinigen
Löschen Sie die TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--force \
--async
MaxDiffusion-Inferenz
In dieser Anleitung wird beschrieben, wie Sie MaxDiffusion-Modelle auf TPU v6e bereitstellen. In dieser Anleitung generieren Sie Bilder mit dem Stable Diffusion XL-Modell.
Hinweise
TPU v6e mit 4 Chips erstellen:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Stellen Sie eine SSH-Verbindung zur TPU her:
gcloud compute tpus tpu-vm ssh TPU_NAME
Conda-Umgebung erstellen
Erstellen Sie ein Verzeichnis für Miniconda:
mkdir -p ~/miniconda3
Laden Sie das Miniconda-Installationsskript herunter:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
Installieren Sie Miniconda:
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
Entfernen Sie das Miniconda-Installationsskript:
rm -rf ~/miniconda3/miniconda.sh
Fügen Sie Miniconda der Variablen
PATH
hinzu:export PATH="$HOME/miniconda3/bin:$PATH"
Laden Sie
~/.bashrc
neu, um die Änderungen auf die VariablePATH
anzuwenden:source ~/.bashrc
Erstellen Sie eine neue Conda-Umgebung:
conda create -n tpu python=3.10
Aktivieren Sie die Conda-Umgebung:
source activate tpu
MaxDiffusion einrichten
Klonen Sie das MaxDiffusion-Repository und wechseln Sie zum MaxDiffusion-Verzeichnis:
https://github.com/google/maxdiffusion.git && cd maxdiffusion
Wechseln Sie zum
mlperf-4.1
-Zweig:git checkout mlperf4.1
Installieren Sie MaxDiffusion:
pip install -e .
Installieren Sie die Abhängigkeiten:
pip install -r requirements.txt
JAX installieren:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Bilder erstellen
Legen Sie Umgebungsvariablen fest, um die TPU-Laufzeit zu konfigurieren:
LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
Bilder mit dem Prompt und den Konfigurationen generieren, die in
src/maxdiffusion/configs/base_xl.yml
definiert sind:python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
Bereinigen
Löschen Sie die TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
Trainingsanleitungen
In den folgenden Abschnitten finden Sie Anleitungen zum Trainieren von MaxText.
MaxDiffusion- und PyTorch-Modelle auf TPU v6e
MaxText und MaxDiffusion
In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.
Im Allgemeinen sind dies die allgemeinen Schritte:
- Erstellen Sie das Basis-Image der Arbeitslast.
- Führen Sie die Arbeitslast mit XPK aus.
- Erstellen Sie den Trainingsbefehl für die Arbeitslast.
- Stellen Sie die Arbeitslast bereit.
- Arbeitslast verfolgen und Messwerte ansehen
- Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
- Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.
Basis-Image erstellen
Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:
Klonen Sie das gewünschte Repository und wechseln Sie zum Verzeichnis des Repositorys:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Konfigurieren Sie Docker so, dass die Google Cloud CLI verwendet wird:
gcloud auth configure-docker
Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Docker-Image mit JAX Stable Stack erstellen
Sie können die Docker-Images „MaxText“ und „MaxDiffusion“ mit dem Basis-Image „JAX Stable Stack“ erstellen.
Der JAX Stable Stack bietet eine einheitliche Umgebung für MaxText und MaxDiffusion, da JAX mit Kernpaketen wie orbax
, flax
und optax
sowie einer gut qualifizierten libtpu.so kombiniert wird, die TPU-Programm-Dienstprogramme und andere wichtige Tools antreibt. Diese Bibliotheken werden auf Kompatibilität getestet, um eine stabile Grundlage für das Erstellen und Ausführen von MaxText und MaxDiffusion zu schaffen und potenzielle Konflikte aufgrund von inkompatiblen Paketversionen zu vermeiden.
Der stabile JAX-Stack enthält eine vollständig veröffentlichte und qualifizierte libtpu.so, die Kernbibliothek, die die Kompilierung, Ausführung und ICI-Netzwerkkonfiguration von TPU-Programmen steuert. Der libtpu-Release ersetzt den bisher von JAX verwendeten Nightly-Build und sorgt mit Qualifikationstests auf PJRT-Ebene in HLO/StableHLO-IRs für eine konsistente Funktionsweise von XLA-Berechnungen auf TPUs.
Wenn Sie das Docker-Image „MaxText“ und „MaxDiffusion“ mit dem JAX Stable Stack erstellen möchten, legen Sie beim Ausführen des docker_build_dependency_image.sh
-Scripts die Variable MODE
auf stable_stack
und die Variable BASEIMAGE
auf das gewünschte Basis-Image fest.
Im folgenden Beispiel wird us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
als Basisbild angegeben:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
Eine Liste der verfügbaren JAX Stable Stack-Basis-Images finden Sie unter JAX Stable Stack-Images in der Artifact Registry.
Arbeitslast mit XPK ausführen
Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte verwenden, die mit MaxText oder MaxDiffusion festgelegt wurden:
BASE_OUTPUT_DIR=gs://YOUR_BUCKET PER_DEVICE_BATCH_SIZE=2 NUM_STEPS=30 MAX_TARGET_LENGTH=8192
Erstellen Sie Ihr Modellskript, das im nächsten Schritt als Trainingsbefehl kopiert wird. Führen Sie das Modellskript noch nicht aus.
MaxText
MaxText ist ein leistungsstarker, hoch skalierbarer Open-Source-LLM, der in reiner Python- und JAX-Programmierung geschrieben wurde und für das Training und die Inferenz auf Google Cloud TPUs und GPUs ausgerichtet ist.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=$BASE_OUTPUT_DIR \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS}" # attention='dot_product'"
Gemma2
Gemma ist eine Familie von Large Language Models (LLMs) mit offenen Gewichten, die von Google DeepMind auf der Grundlage der Gemini-Forschung und -Technologie entwickelt wurden.
# Requires v6e-256 python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine sparse MoE-Architektur (Mixture of Experts) verwendet.
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama ist eine Familie offener Large Language Models (LLMs), die von Meta entwickelt wurden.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reiner Python- und JAX-Programmierung geschrieben wurden und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert.
Sie müssen einen bestimmten Branch installieren, um MaxDiffusion auszuführen:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Trainingsskript:
cd maxdiffusion && OUT_DIR=${your_own_bucket} python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \ run_name=v6e-sd2 \ split_head_dim=True \ attention=flash \ train_new_unet=false \ norm_num_groups=16 \ output_dir=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ [dcn_data_parallelism=2] \ enable_profiler=True \ skip_first_n_steps_for_profiler=95 \ max_train_steps=${NUM_STEPS} ] write_metrics=True'
Führen Sie das Modell mit dem Skript aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag
--base-docker-image
angeben, um das MaxText-Basis-Image zu verwenden, oder das Flag--docker-image
und das gewünschte Image.Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag
--enable-debug-logs
einfügen. Weitere Informationen finden Sie unter JAX bei MaxText debuggen.Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Fügen Sie dazu das Flag
--use-vertex-tensorboard
hinzu. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \ --tpu-type=ACCELERATOR_TYPE \ --num-slices=NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command YOUR_MODEL_SCRIPT
Ersetzen Sie die folgenden Variablen:
- CLUSTER_NAME: Der Name Ihres XPK-Clusters.
- ACCELERATOR_TYPE: Die Version und Größe Ihrer TPU. Beispiel:
v6e-256
. - NUM_SLICES: Die Anzahl der TPU-Slices.
- YOUR_MODEL_SCRIPT: Das Modellskript, das als Trainingsbefehl ausgeführt werden soll.
Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können, ähnlich dem folgenden:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Öffnen Sie den Link und klicken Sie auf den Tab Protokolle, um Ihre Arbeitslast in Echtzeit zu verfolgen.
JAX in MaxText debuggen
Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird:
- XPK-Arbeitslastliste
- XPK-Inspektor
- Aktivieren Sie beim Erstellen der XPK-Arbeitslast mit dem Flag
--enable-debug-logs
ausführliche Protokolle in Ihren Arbeitslastprotokollen.
JAX auf MaxText mit Vertex AI überwachen
Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI aufrufen
- Erhöhen Sie die Anzahl der Resource Management (CRUD)-Anfragen für die von Ihnen verwendete Zone von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
Installieren Sie Abhängigkeiten wie
cloud-accelerator-diagnostics
für Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Erstellen Sie Ihren XPK-Cluster mit dem Flag
--create-vertex-tensorboard
, wie unter Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihre XPK-Arbeitslast mit dem Flag
--use-vertex-tensorboard
und dem optionalen Flag--experiment-name
ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen.
Die Protokolle enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich wie hier:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Sie finden den Link zu Vertex AI TensorBoard auch in der Google Cloud Console. Rufen Sie in der Google Cloud Console Vertex AI-Tests auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.
Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR}
angegeben haben.
XPK-Arbeitslasten löschen
Verwenden Sie den Befehl xpk workload delete
, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.
XPK-Cluster löschen
Verwenden Sie den Befehl xpk cluster delete
, um einen Cluster zu löschen:
python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID
Llama und PyTorch
In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf einer TPU v6e mit dem Dataset WikiText trainieren. Außerdem können Nutzer hier auf PyTorch-TPU-Modellbeschreibungen als Docker-Images zugreifen.
Installation
Installieren Sie den pytorch-tpu/transformers
-Fork von Hugging Face Transformers und die Abhängigkeiten in einer virtuellen Umgebung:
git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate
Modellkonfigurationen einrichten
Der Trainingsbefehl im nächsten Abschnitt Modellscript erstellen verwendet zwei JSON-Konfigurationsdateien, um Modellparameter und die FSDP-Konfiguration (Fully Sharded Data Parallel) zu definieren. Das FSDP-Sharding wird verwendet, damit die Modellgewichte während des Trainings zu einer größeren Batchgröße passen. Beim Training mit kleineren Modellen reicht es möglicherweise aus, nur Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Aufteilen von Tensoren auf Geräte in PyTorch/XLA finden Sie im PyTorch/XLA SPMD-Nutzerhandbuch.
Erstellen Sie die Konfigurationsdatei für die Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama3-8B. Für andere Modelle finden Sie die Konfiguration auf Hugging Face. Siehe beispielsweise die Llama2-7B-Konfiguration.
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 }
Erstellen Sie die FSDP-Konfigurationsdatei:
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true }
Weitere Informationen zu FSDP finden Sie unter FSDPv2.
Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre TPU-VMs hoch:
gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Sie können die Konfigurationsdateien auch in Ihrem aktuellen Arbeitsverzeichnis erstellen und das Flag
--base-docker-image
in XPK verwenden.
Modellscript erstellen
Erstellen Sie Ihr Modellscript und geben Sie die Konfigurationsdatei für die Modellparameter mit dem Flag --config_name
und die FSDP-Konfigurationsdatei mit dem Flag --fsdp_config
an.
Sie führen dieses Script im nächsten Abschnitt Modell ausführen auf Ihrer TPU aus. Führen Sie das Modellskript noch nicht aus.
PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 PROFILE_EPOCH=0 PROFILE_STEP=3 PROFILE_DURATION_MS=100000 PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 8 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/config-8B.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp_config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20
Modell ausführen
Führen Sie das Modell mit dem im vorherigen Schritt erstellten Script aus: Modellscript erstellen.
Wenn Sie eine TPU-VM mit einem einzelnen Host verwenden (z. B. v6e-4
), können Sie den Trainingsbefehl direkt auf der TPU-VM ausführen. Wenn Sie eine TPU-VM mit mehreren Hosts verwenden, können Sie das Script mit dem folgenden Befehl gleichzeitig auf allen Hosts ausführen:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=YOUR_COMMAND
Fehlerbehebung bei PyTorch/XLA
Wenn Sie die optionalen Variablen für das Debuggen im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell an dem Speicherort gespeichert, der in der Variablen PROFILE_LOGDIR
angegeben ist. Sie können die dort gespeicherte xplane.pb
-Datei extrahieren und mit tensorboard
die Profile in Ihrem Browser anzeigen lassen. Folgen Sie dazu der TensorBoard-Anleitung. Wenn PyTorch/XLA nicht wie erwartet funktioniert, lesen Sie den Leitfaden zur Fehlerbehebung. Dort finden Sie Vorschläge zum Debuggen, Profilieren und Optimieren Ihrer Modelle.
DLRM DCN v2-Anleitung
In dieser Anleitung wird beschrieben, wie Sie das DLRM DCN v2-Modell auf einer TPU v6e trainieren.
Wenn Sie mehrere Hosts verwenden, setzen Sie tpu-runtime
mit der entsprechenden TensorFlow-Version zurück. Führen Sie dazu den folgenden Befehl aus. Wenn Sie die Anwendung auf einem einzelnen Host ausführen, müssen Sie die folgenden beiden Befehle nicht ausführen.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
SSH-Verbindung zu worker-0 herstellen
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
TPU-Namen festlegen
export TPU_NAME=${TPU_NAME}
DLRM v2 ausführen
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Führen Sie script.sh
aus.
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Die folgenden Flags sind erforderlich, um Empfehlungsarbeitslasten (DLRM DCN) auszuführen:
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Benchmarking-Ergebnisse
Der folgende Abschnitt enthält Benchmarking-Ergebnisse für DLRM DCN v2 und MaxDiffusion auf v6e.
DLRM DCN v2
Das DLRM DCN v2-Trainingsskript wurde in verschiedenen Skalen ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.
v6e-64 | v6e-128 | v6e-256 | |
Trainingsschritte | 7000 | 7000 | 7000 |
Globale Batchgröße | 131.072 | 262144 | 524.288 |
Durchsatz (Beispiele/Sek.) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Wir haben das Trainingsskript für MaxDiffusion auf einem v6e-4, einem v6e-16 und einem 2xv6e-16 ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.
v6e-4 | v6e-16 | Zwei v6e-16 | |
Trainingsschritte | 0.069 | 0,073 | 0,13 |
Globale Batchgröße | 8 | 32 | 64 |
Durchsatz (Beispiele/Sek.) | 115,9 | 438,4 | 492,3 |
Sammlungen
In Version 6e wird die neue Funktion „Sammlungen“ für Nutzer eingeführt, die Bereitstellungslasten ausführen. Die Sammlungsfunktion gilt nur für Version 6e.
Mit Sammlungen können Sie Google Cloud mitteilen, welche Ihrer TPU-Knoten zu einer Bereitstellungsarbeitslast gehören. So kann die zugrunde liegende Google Cloud-Infrastruktur Unterbrechungen, die im normalen Betrieb auf Trainingslasten angewendet werden können, begrenzen und optimieren.
Sammlungen aus der Cloud TPU API verwenden
Eine Sammlung mit einem einzelnen Host in der Cloud TPU API ist eine Ressourcenwarteschlange, für die ein spezielles Flag (--workload-type = availability-optimized
) festgelegt ist, um der zugrunde liegenden Infrastruktur anzugeben, dass sie für die Bereitstellung von Arbeitslasten verwendet werden soll.
Mit dem folgenden Befehl wird eine Sammlung mit einem einzelnen Host mithilfe der Cloud TPU API bereitgestellt:
gcloud alpha compute tpus queued-resources create COLLECTION_NAME \ --project=project name \ --zone=zone name \ --accelerator-type=accelerator type \ --node-count=number of nodes \ --workload-type=availability-optimized
Überwachen und profilieren
Cloud TPU v6e unterstützt Monitoring und Profiling mit denselben Methoden wie frühere Cloud TPU-Generationen. Weitere Informationen zum Monitoring finden Sie unter TPU-VMs überwachen.