Trillium (v6e) – Einführung
In dieser Dokumentation, in der TPU API und in den Protokollen wird „v6e“ für Trillium verwendet. „v6e“ steht für die sechste Generation von TPU von Google.
Mit 256 Chips pro Pod hat die v6e-Architektur viele Ähnlichkeiten mit v5e. Dieses System ist für das Training, die Feinabstimmung und die Bereitstellung von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks) optimiert.
Informationen zur Systemarchitektur und zu den Konfigurationen von v6e finden Sie im Dokument v6e.
In diesem Einführungsdokument liegt der Schwerpunkt auf den Prozessen für das Modelltraining und die Bereitstellung mit den Frameworks JAX, PyTorch oder TensorFlow. Mit jedem Framework können Sie TPUs mithilfe von Ressourcen in der Warteschlange oder der Google Kubernetes Engine (GKE) bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.
Allgemeines Verfahren zum Trainieren oder Bereitstellen eines Modells mit v6e
- Google Cloud Projekt vorbereiten
- Sichere Kapazität
- TPU-Umgebung einrichten
- Cloud TPU-Umgebung bereitstellen
- Arbeitslast für Modelltraining oder Inferenz ausführen
- Bereinigen
Google Cloud Projekt vorbereiten
- Melden Sie sich in Ihrem Google-Konto an. Wenn Sie noch kein Konto haben, melden Sie sich hier für ein neues Konto an.
- Wählen Sie in der Google Cloud Console auf der Seite der Projektauswahl ein Cloud-Projekt aus oder erstellen Sie eines.
- Aktivieren Sie die Abrechnung für Ihr Google Cloud-Projekt. Für die gesamte Nutzung von Google Cloud ist eine Abrechnung erforderlich.
- Installieren Sie die gcloud alpha-Komponenten.
Führen Sie den folgenden Befehl aus, um die neueste Version der
gcloud
-Komponenten zu installieren.gcloud components update
Aktivieren Sie die TPU API mit dem folgenden
gcloud
-Befehl in Cloud Shell. Sie können sie auch über die Google Cloud Console aktivieren.gcloud services enable tpu.googleapis.com
Berechtigungen mit dem TPU-Dienstkonto für die Compute Engine API aktivieren
Dienstkonten ermöglichen dem Cloud TPU-Dienst, auf andere Google Cloud-Dienste zuzugreifen. Ein nutzerverwaltetes Dienstkonto ist eine empfohlene Google Cloud-Methode. Folgen Sie diesen Anleitungen, um Rollen zu erstellen und zu gewähren. Folgende Rollen sind erforderlich:
- TPU-Administrator
- Storage-Administrator
- Log-Autor
- Monitoring-Messwert-Autor
a. Richten Sie XPK-Berechtigungen mit Ihrem Nutzerkonto für GKE ein: XPK.
Authentifizieren Sie sich mit Ihrem Google-Konto und legen Sie die Standardprojekt-ID und -Zone fest.
auth login
autorisiert gcloud, mit Google-Nutzeranmeldedaten auf Google Cloud zuzugreifen.
PROJECT_ID
ist der Google Cloud Projektname.
ZONE
ist die Zone, in der Sie die TPU erstellen möchten.gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Erstellen Sie eine Dienstidentität für die TPU-VM.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Sichere Kapazität
Wenden Sie sich an Ihren Cloud TPU-Supportmitarbeiter, um ein TPU-Kontingent anzufordern und Fragen zur Kapazität zu stellen.
Cloud TPU-Umgebung bereitstellen
v6e-TPUs können mit GKE, mit GKE und XPK (einem Befehlszeilen-Wrapper für GKE) oder als Ressourcen in der Warteschlange bereitgestellt und verwaltet werden.
Vorbereitung
- Prüfen Sie, ob Ihr Projekt ein ausreichendes
TPUS_PER_TPU_FAMILY
-Kontingent hat. Dieses gibt die maximale Anzahl von Chips an, auf die Sie in IhremGoogle Cloud -Projekt zugreifen können. - v6e wurde mit der folgenden Konfiguration getestet:
- Python
3.10
oder höher - Nightly-Softwareversionen:
- tägliche JAX
0.4.32.dev20240912
- nightly LibTPU
0.1.dev20240912+nightly
- tägliche JAX
- Stabile Softwareversionen:
- JAX + JAX-Bibliothek der Version 0.4.37
- Python
Prüfen Sie, ob Ihr Projekt ein ausreichendes TPU-Kontingent für Folgendes hat:
- TPU-VM-Kontingent
- Kontingent für IP-Adressen
Hyperdisk-Balance-Kontingent
Nutzerprojektberechtigungen
- Wenn Sie GKE mit XPK verwenden, finden Sie unter Cloud Console-Berechtigungen für das Nutzer- oder Dienstkonto Informationen zu den Berechtigungen, die zum Ausführen von XPK erforderlich sind.
Umgebungsvariablen
Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Beschreibung der Befehls-Flags
Variable | Beschreibung |
NODE_ID | Die vom Nutzer zugewiesene ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird. |
PROJECT_ID | Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter |
ZONE | Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. |
ACCELERATOR_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“.
Beispiel: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich). |
QUEUED_RESOURCE_ID | Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage. |
VALID_DURATION | Die Dauer, für die die angeforderte Ressource gültig ist. |
NETWORK_NAME | Der Name eines sekundären Netzwerks, das verwendet werden soll. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
Optimierungen der Netzwerkleistung
Für die beste Leistung verwenden Sie ein Netzwerk mit 8.896 MTU (maximale Übertragungseinheit).
Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Ethernet-Standard) oder 8.896 Byte (maximal möglich). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.
Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.
Im folgenden Beispiel wird ein Netzwerk mit 8.896 MTU erstellt.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
Mehrere NICs verwenden (Option für Multi-Slice)
Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Verwenden Sie die folgenden Befehle, um eine benutzerdefinierte IP-Weiterleitung für das Netzwerk und das Subnetz zu erstellen.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT_ID}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT_ID}" \
--enable-logging
Nachdem ein Multi-Network Slice erstellt wurde, können Sie prüfen, ob beide NICs verwendet werden. Dazu richten Sie einen XPK-Cluster ein und führen --command ifconfig
als Teil der XPK-Arbeitslast aus.
Verwenden Sie den folgenden xpk workload
-Befehl, um die Ausgabe des Befehls ifconfig
in den Cloud Console-Protokollen anzuzeigen, und prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8896 haben.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8.896 haben. Sie können prüfen, ob Multi-NIC verwendet wird, indem Sie den Befehl „–command ifconfig“ als Teil der XPK-Arbeitslast ausführen. Sehen Sie sich dann die gedruckte Ausgabe dieser xpk-Arbeitslast in den Cloud Console-Protokollen an und prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8896 haben.
Verbesserte TCP-Einstellungen
Bei TPUs, die über die Benutzeroberfläche für Ressourcen in der Warteschlange erstellt wurden, können Sie den folgenden Befehl ausführen, um die Netzwerkleistung zu verbessern, indem Sie die TCP-Empfangsbufferlimits erhöhen.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" \ --zone "$ZONE" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
Bereitstellung mit in die Warteschlange gestellten Ressourcen
Die zugewiesene Kapazität kann mit dem Befehl create
für Ressourcen in der Warteschlange bereitgestellt werden.
Erstellen Sie eine Anfrage für eine in die Warteschlange gestellte TPU-Ressource.
Das
--reserved
-Flag ist nur für reservierte Ressourcen erforderlich, nicht für On-Demand-Ressourcen.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved] # The following flags are only needed if you are using Multislice. --node-count node-count # Number of slices in a Multislice \ --node-prefix node-prefix # An optional user-defined node prefix; the default is QUEUED_RESOURCE_ID.
Wenn die Anfrage für die Ressourcenwarteschlange erfolgreich erstellt wurde, hat das Feld „response“ entweder den Status „WAITING_FOR_RESOURCES“ oder „FAILED“. Wenn die in der Warteschlange befindliche Ressourcenanfrage den Status „WAITING_FOR_RESOURCES“ (Warten auf Ressourcen) hat, wurde die Ressource der Warteschlange hinzugefügt und wird bereitgestellt, sobald genügend zugewiesene TPU-Kapazität vorhanden ist. Wenn die Anfrage für die Ressourcenwarteschlange den Status „FAILED“ (Fehlgeschlagen) hat, wird der Grund für den Fehler in der Ausgabe angezeigt. Die anstehende Ressourcenanfrage läuft ab, wenn innerhalb der angegebenen Dauer kein v6e bereitgestellt wird. Der Status ändert sich dann in „FAILED“. Weitere Informationen finden Sie in der öffentlichen Dokumentation zu Ressourcen in der Warteschlange.
Wenn Ihre in die Warteschlange gestellte Ressourcenanfrage den Status „AKTIV“ hat, können Sie über SSH eine Verbindung zu Ihren TPU-VMs herstellen. Verwenden Sie die Befehle
list
oderdescribe
, um den Status der in der Warteschlange befindlichen Ressource abzufragen.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Wenn sich die erwartete Ressource im Status „AKTIV“ befindet, sieht die Ausgabe in etwa so aus:
state: state: ACTIVE
TPU-VMs verwalten Informationen zu den Optionen zum Verwalten Ihrer TPU-VMs finden Sie unter TPU-VMs verwalten.
Über SSH eine Verbindung zu TPU-VMs herstellen
Sie können Binärdateien auf jeder TPU-VM in Ihrem TPU-Speicherplatz installieren und Code ausführen. Im Abschnitt VM-Typen erfahren Sie, wie viele VMs Ihr Slice haben wird.
Wenn Sie die Binärdateien installieren oder Code ausführen möchten, können Sie über SSH mit dem Befehl
tpu-vm ssh
eine Verbindung zu einer VM herstellen.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Wenn Sie über SSH eine Verbindung zu einer bestimmten VM herstellen möchten, verwenden Sie das Flag
--worker
, das einem Index mit der Basis 0 folgt:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Wenn Sie Slice-Formen mit mehr als 8 Chips haben, befinden sich mehrere VMs in einem Slice. Verwenden Sie in diesem Fall die Parameter
--worker=all
und--command
in Ihremgcloud alpha compute tpus tpu-vm ssh
-Befehl, um einen Befehl gleichzeitig auf allen VMs auszuführen. Beispiel:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
In die Warteschlange gestellte Ressource löschen
Lösche am Ende der Sitzung eine in die Warteschlange gestellte Ressource oder entferne in die Warteschlange gestellte Ressourcenanfragen, die den Status „FAILED“ (Fehlgeschlagen) haben. Wenn Sie eine in die Warteschlange gestellte Ressource löschen möchten, löschen Sie in zwei Schritten den Ausschnitt und dann die Anfrage für die in die Warteschlange gestellte Ressource:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
TPUs der Version 6e mit GKE oder XPK bereitstellen
Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter TPUs in GKE planen erfahren Sie, wie Sie Ihre TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC und mehrere NICs.
Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network ${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster $CLUSTER_NAME \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=$NUM_SLICES \ --tpu-type=$TPU_TYPE \ --zone=$ZONE \ --project=$PROJECT \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Beschreibung der Befehls-Flags
Variable | Beschreibung |
CLUSTER_NAME | Der vom Nutzer zugewiesene Name für den XPK-Cluster. |
PROJECT_ID | Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter |
ZONE | Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. |
TPU_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
NUM_SLICES | Die Anzahl der Scheiben, die Sie erstellen möchten |
CLUSTER_ARGUMENTS | Das zu verwendende Netzwerk und Subnetzwerk.
Beispiel: „--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}“ |
NUM_SLICES | Die Anzahl der zu erstellenden Segmente. |
NETWORK_NAME | Der Name eines sekundären Netzwerks, das verwendet werden soll. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
Befehle zum Erstellen eines XPK-Clusters mit Unterstützung mehrerer NICs
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create "${NETWORK_NAME_1}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_1}" \ --network="${NETWORK_NAME_1}" \ --range=10.11.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_1}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_1}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create "${NETWORK_NAME_2}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \ --network="${NETWORK_NAME_2}" \ --range=10.10.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_2}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_2}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard
Beschreibung der Befehls-Flags
Variable | Beschreibung |
CLUSTER_NAME | Der vom Nutzer zugewiesene Name für den XPK-Cluster. |
PROJECT_ID | Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter |
ZONE | Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. |
TPU_TYPE | Weitere Informationen finden Sie unter Beschleunigertypen. |
NUM_SLICES | Die Anzahl der Scheiben, die Sie erstellen möchten |
CLUSTER_ARGUMENTS | Das zu verwendende Netzwerk und Subnetzwerk.
Beispiel: „--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}“ |
NODE_POOL_ARGUMENTS | Zu verwendendes zusätzliches Knotennetzwerk.
Beispiel: „--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}“ |
NUM_SLICES | Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich). |
NETWORK_NAME | Der Name eines sekundären Netzwerks, das verwendet werden soll. |
NETWORK_FW_NAME | Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll. |
Framework einrichten
In diesem Abschnitt wird die allgemeine Einrichtung für das Training von ML-Modellen mit den Frameworks JAX, PyTorch oder TensorFlow beschrieben. Sie können TPUs mithilfe von Ressourcen in der Warteschlange oder GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder Kubernetes-Befehlen erfolgen.
Einrichtung für JAX
In diesem Abschnitt finden Sie Beispiele für die Ausführung von JAX-Arbeitslasten in GKE mit oder ohne XPK sowie für die Verwendung von Ressourcen in der Warteschlange.
JAX mit GKE einrichten
Im folgenden Beispiel wird ein einzelner 2 × 2-Host mit einer Kubernetes-YAML-Datei eingerichtet.
Einzelnes Slice auf einem einzelnen Host
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:
Total TPU chips: 4
Einzelnes Slice auf mehreren Hosts
Im folgenden Beispiel wird ein 4 × 4-Knotenpool mit mehreren Hosts mithilfe einer Kubernetes-YAML-Datei eingerichtet.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:
Total TPU chips: 16
Multislice auf mehreren Hosts
Im folgenden Beispiel werden zwei 4 × 4-Multihost-Knotenpools mit einer Kubernetes-YAML-Datei eingerichtet.
Als Voraussetzung müssen Sie JobSet v0.2.3 oder höher installieren.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:
Total TPU chips: 32
Weitere Informationen finden Sie in der GKE-Dokumentation unter Multi-Slice-Arbeitslast ausführen.
Aktivieren Sie für eine bessere Leistung hostNetwork.
Mehrere NICs
Damit Sie in GKE mehrere NICs nutzen können, muss das Kubernetes-Pod-Manifest zusätzliche Anmerkungen enthalten. Im Folgenden finden Sie ein Beispielmanifest für eine Arbeitslast mit mehreren NICs ohne TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Wenn Sie exec
in den Kubernetes-Pod eingeben, sollte die zusätzliche NIC mit dem folgenden Code angezeigt werden.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
JAX mit GKE und XPK einrichten
Ein Beispiel finden Sie in der xpk-README-Datei.
Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText ausführen.
JAX mit Ressourcen in der Warteschlange einrichten
Mit gcloud alpha compute tpus tpu-vm ssh
können Sie JAX gleichzeitig auf allen TPU-VMs in Ihrem Slice oder Ihren Slices installieren. Fügen Sie für Multislice --node=all
hinzu.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Sie können den folgenden Python-Code ausführen, um zu prüfen, wie viele TPU-Kerne in Ihrem Slice verfügbar sind, und um zu testen, ob alles richtig installiert ist. Die hier gezeigten Ausgaben wurden mit einem v6e-16-Speicher erstellt:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Die Ausgabe sieht in etwa so aus:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
Fehlerbehebung bei JAX-Einrichtungen
Als allgemeinen Tipp können wir Ihnen empfehlen, im Manifest Ihrer GKE-Arbeitslast ausführliches Logging zu aktivieren. Reichen Sie die Protokolle dann beim GKE-Support ein.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Fehlermeldungen
no endpoints available for service 'jobset-webhook-service'
Dieser Fehler bedeutet, dass der Jobsatz nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods der Bereitstellung „jobset-controller-manager“ ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.
TPU initialization failed: Failed to connect
Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein. GKE 1.31 wird nicht unterstützt.
Einrichtung für PyTorch
In diesem Abschnitt wird beschrieben, wie Sie PJRT unter v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.
PyTorch mit GKE und XPK einrichten
Sie können den folgenden Docker-Container mit XPK verwenden, in dem die PyTorch-Abhängigkeiten bereits installiert sind:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Führen Sie den folgenden Befehl aus, um eine XPK-Arbeitslast zu erstellen:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
Mit --base-docker-image
wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.
PyTorch mit in die Warteschlange gestellten Ressourcen einrichten
Führen Sie die folgenden Schritte aus, um PyTorch mit Ressourcen in der Warteschlange zu installieren und ein kleines Script auf v6e auszuführen.
Abhängigkeiten über SSH installieren, um auf die VMs zuzugreifen
Fügen Sie für Multislice --node=all
hinzu:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Leistung von Modellen mit großen, häufigen Zuweisungen verbessern
Bei Modellen mit großen, häufigen Zuweisungen haben wir festgestellt, dass die Verwendung von tcmalloc
die Leistung im Vergleich zur Standardimplementierung von malloc
erheblich verbessert. Daher ist tcmalloc
die Standard-malloc
auf der TPU-VM. Je nach Arbeitslast kann tcmalloc
jedoch zu einer Verlangsamung führen, z. B. bei DLRM mit sehr großen Zuweisungen für die Einbettungstabellen. In diesem Fall können Sie versuchen, die folgende Variable zurückzusetzen und stattdessen die Standardeinstellung malloc
zu verwenden:
unset LD_PRELOAD
Mit einem Python-Script eine Berechnung auf einer v6e-VM ausführen:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Dadurch wird eine Ausgabe generiert, die etwa so aussieht:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Einrichtung für TensorFlow
Für die öffentliche Vorschau von v6e wird nur die Laufzeitversion „tf-nightly“ unterstützt.
Sie können tpu-runtime
mit der mit v6e kompatiblen TensorFlow-Version zurücksetzen, indem Sie die folgenden Befehle ausführen:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Verwenden Sie SSH, um auf worker-0 zuzugreifen:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Installieren Sie TensorFlow auf worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exportieren Sie die Umgebungsvariable TPU_NAME
:
export TPU_NAME=v6e-16
Mit dem folgenden Python-Script können Sie prüfen, wie viele TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles richtig installiert ist. Die angezeigten Ausgaben wurden mit einem v6e-16-Speicherplatz erstellt:
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
Die Ausgabe sieht in etwa so aus:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e mit SkyPilot
Sie können TPU v6e mit SkyPilot verwenden. Führen Sie die folgenden Schritte aus, um SkyPilot v6e-bezogene Standort-/Preisinformationen hinzuzufügen.
Fügen Sie am Ende von
~/.sky/catalogs/v5/gcp/vms.csv
Folgendes hinzu :,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Geben Sie die folgenden Ressourcen in einer YAML-Datei an:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Cluster mit TPU v6e starten:
sky launch tpu_v6.yaml -c tpu_v6
Stellen Sie eine SSH-Verbindung zur TPU v6e her:
ssh tpu_v6
Anleitungen für Inferenz
In den folgenden Anleitungen wird gezeigt, wie Sie die Inferenz auf einer TPU v6e ausführen:
Trainingsbeispiele
In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf TPU v6e.
MaxText- und MaxDiffusion-Training auf einer v6e-Cloud TPU-VM
In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.
Im Allgemeinen sind dies die allgemeinen Schritte:
- Erstellen Sie das Basis-Image der Arbeitslast.
- Führen Sie Ihre Arbeitslast mit XPK aus.
- Erstellen Sie den Trainingsbefehl für die Arbeitslast.
- Stellen Sie die Arbeitslast bereit.
- Arbeitslast verfolgen und Messwerte ansehen
- Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
- Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.
Basis-Image erstellen
Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:
Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis für das Repository:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Konfigurieren Sie Docker so, dass die Google Cloud CLI verwendet wird:
gcloud auth configure-docker
Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Docker-Image mit JAX Stable Stack erstellen
Sie können die Docker-Images „MaxText“ und „MaxDiffusion“ mit dem Basis-Image „JAX Stable Stack“ erstellen.
JAX Stable Stack bietet eine einheitliche Umgebung für MaxText und MaxDiffusion, indem JAX mit Kernpaketen wie orbax
, flax
und optax
sowie einer gut qualifizierten libtpu.so kombiniert wird, die TPU-Programm-Dienstprogramme und andere wichtige Tools antreibt. Diese Bibliotheken werden auf Kompatibilität getestet und bieten eine stabile Grundlage für die Erstellung und Ausführung von MaxText und MaxDiffusion. So werden potenzielle Konflikte aufgrund von inkompatiblen Paketversionen vermieden.
Der stabile JAX-Stack enthält eine vollständig veröffentlichte und qualifizierte libtpu.so, die Kernbibliothek, die die Kompilierung, Ausführung und ICI-Netzwerkkonfiguration von TPU-Programmen steuert. Der libtpu-Release ersetzt den bisher von JAX verwendeten Nightly-Build und sorgt mit Qualifikationstests auf PJRT-Ebene in HLO/StableHLO-IRs für eine konsistente Funktionsweise von XLA-Berechnungen auf TPUs.
Wenn Sie das Docker-Image für MaxText und MaxDiffusion mit dem JAX Stable Stack erstellen möchten, legen Sie beim Ausführen des docker_build_dependency_image.sh
-Scripts die Variable MODE
auf stable_stack
und die Variable BASEIMAGE
auf das gewünschte Basis-Image fest.
Im folgenden Beispiel wird us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
als Basis-Image angegeben:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
Eine Liste der verfügbaren JAX Stable Stack-Basis-Images finden Sie unter JAX Stable Stack-Images in Artifact Registry.
Arbeitslast mit XPK ausführen
Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte verwenden, die mit MaxText oder MaxDiffusion festgelegt wurden:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Erstellen Sie das Modellskript. Dieses Script wird in einem späteren Schritt als Trainingsbefehl kopiert.
Führen Sie das Modellskript noch nicht aus.
MaxText
MaxText ist ein leistungsstarkes, hoch skalierbares Open-Source-LLM, das in reiner Python- und JAX-Programmierung geschrieben wurde und auf Google Cloud TPUs und GPUs für Training und Inferenz ausgerichtet ist.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind entwickelt wurden und auf der Gemini-Forschung und -Technologie basieren.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine sparse MoE-Architektur (Mixture of Experts) verwendet.
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reiner Python- und JAX-Programmierung geschrieben wurden und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert.
Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden
git checkout
-Befehl gezeigt.git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Trainingsskript:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
Führen Sie das Modell mit dem im vorherigen Schritt erstellten Script aus. Sie müssen entweder das Flag
--base-docker-image
angeben, um das MaxText-Basisbild zu verwenden, oder das Flag--docker-image
und das gewünschte Bild.Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag
--enable-debug-logs
einfügen. Weitere Informationen finden Sie unter JAX in MaxText debuggen.Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Fügen Sie dazu das Flag
--use-vertex-tensorboard
hinzu. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=$ACCELERATOR_TYPE \ --num-slices=$NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command $YOUR-MODEL-SCRIPT
Exportieren Sie die folgenden Variablen:
export ClUSTER_NAME=CLUSTER_NAME: Der Name Ihres XPK-Clusters. export ACCELERATOR_TYPEACCELERATOR_TYPE: Die Version und Größe Ihrer TPU. Beispiel:
v6e-256
export NUM_SLICES=NUM_SLICES: Die Anzahl der TPU-Scheiben. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: Das Modellscript, das als Trainingsbefehl ausgeführt werden soll.Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können, ähnlich dem folgenden:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Öffnen Sie den Link und klicken Sie auf den Tab Protokolle, um Ihre Arbeitslast in Echtzeit zu verfolgen.
JAX in MaxText debuggen
Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird.
- XPK-Arbeitslastliste
- XPK-Inspektor
- Aktivieren Sie beim Erstellen der XPK-Arbeitslast mit dem Flag
--enable-debug-logs
ausführliche Protokolle in Ihren Arbeitslastprotokollen.
JAX auf MaxText mit Vertex AI überwachen
Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI aufrufen
- Erhöhen Sie die Anzahl der Resource Management (CRUD)-Anfragen für die von Ihnen verwendete Zone von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
Installieren Sie Abhängigkeiten wie
cloud-accelerator-diagnostics
für Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Erstellen Sie Ihren XPK-Cluster mit dem Flag
--create-vertex-tensorboard
, wie unter Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihre XPK-Arbeitslast mit dem Flag
--use-vertex-tensorboard
und dem optionalen Flag--experiment-name
ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen.
Die Protokolle enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich wie hier:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Sie finden den Link zu Vertex AI TensorBoard auch in der Google Cloud Console. Rufen Sie in der Google Cloud Console Vertex AI-Tests auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.
Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR}
angegeben haben.
XPK-Arbeitslasten löschen
Verwenden Sie den Befehl xpk workload delete
, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.
XPK-Cluster löschen
Verwenden Sie den Befehl xpk cluster delete
, um einen Cluster zu löschen:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone $ZONE --project $PROJECT_ID
Llama- und PyTorch/XLA-Training auf einer v6e-Cloud TPU-VM
In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf einer TPU v6e mit dem Dataset WikiText trainieren.
Zugriff auf Hugging Face und das Llama 3-Modell erhalten
Sie benötigen ein Nutzerzugriffstoken für Hugging Face, um dieses Tutorial auszuführen. Informationen zum Erstellen und Verwenden von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.
Außerdem benötigen Sie die Berechtigung, auf das Llama 3 8B-Modell auf Hugging Face zuzugreifen. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie Zugriff.
TPU-VM erstellen
Erstellen Sie eine TPU v6e mit 8 Chips, um die Anleitung auszuführen.
Richten Sie Umgebungsvariablen ein:
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT=YOUR_PROJECT export ZONE=YOUR_ZONE
So erstellen Sie eine TPU-VM:
gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \ --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT
Installation
Installieren Sie den pytorch-tpu/transformers
-Fork von Hugging Face Transformers und die Abhängigkeiten. Diese Anleitung wurde mit den folgenden Abhängigkeitsversionen getestet, die in diesem Beispiel verwendet werden:
torch
: kompatibel mit 2.5.0torch_xla[tpu]
: kompatibel mit 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Modellkonfigurationen einrichten
Der Trainingsbefehl im nächsten Abschnitt, Modell ausführen, verwendet zwei JSON-Konfigurationsdateien, um Modellparameter und die FSDP-Konfiguration (Fully Sharded Data Parallel) zu definieren. Das FSDP-Sharding wird für die Modellgewichte verwendet, damit sie während des Trainings zu einer größeren Batchgröße passen. Beim Training mit kleineren Modellen reicht es möglicherweise aus, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im PyTorch/XLA SPMD-Nutzerhandbuch.
Erstellen Sie die Konfigurationsdatei für die Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama3-8B. Für andere Modelle finden Sie die Konfiguration auf Hugging Face. Siehe beispielsweise die Llama2-7B-Konfiguration.
cat > llama-config.json <
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF Erstellen Sie die FSDP-Konfigurationsdatei:
cat > fsdp-config.json <
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF Weitere Informationen zu FSDP finden Sie unter FSDPv2.
Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre TPU-VMs hoch:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Modell ausführen
Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das run_clm.py
-Script aus, um das Llama 3 8B-Modell mit dem WikiText-Dataset zu trainieren. Das Training dauert auf einer TPU v6e-8 etwa 10 Minuten.
Melden Sie sich mit dem folgenden Befehl auf Ihrer TPU in Hugging Face an:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Modelltraining ausführen:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Fehlerbehebung bei PyTorch/XLA
Wenn Sie die optionalen Variablen für das Debuggen im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell an dem Speicherort gespeichert, der in der Variablen PROFILE_LOGDIR
angegeben ist. Sie können die dort gespeicherte xplane.pb
-Datei extrahieren und mit tensorboard
die Profile in Ihrem Browser anzeigen lassen. Folgen Sie dazu der TensorBoard-Anleitung. Wenn PyTorch/XLA nicht wie erwartet funktioniert, lesen Sie den Leitfaden zur Fehlerbehebung. Dort finden Sie Vorschläge zum Debuggen, Profilieren und Optimieren Ihres Modells.
DLRM DCN v2-Training auf v6e
In dieser Anleitung wird beschrieben, wie Sie das DLRM DCN v2-Modell auf einer TPU v6e trainieren. Sie müssen eine TPU v6e mit 64, 128 oder 256 Chips bereitstellen.
Wenn Sie mehrere Hosts verwenden, setzen Sie tpu-runtime
mit der entsprechenden TensorFlow-Version zurück. Führen Sie dazu den folgenden Befehl aus. Wenn Sie die Anwendung auf einem einzelnen Host ausführen, müssen Sie die folgenden beiden Befehle nicht ausführen.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
SSH-Verbindung zu worker-0 herstellen
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
TPU-Namen festlegen
export TPU_NAME=${TPU_NAME}
DLRM v2 ausführen
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Führen Sie script.sh
aus.
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Die folgenden Flags sind erforderlich, um Empfehlungslasten (DLRM DCN) auszuführen:
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Benchmarking-Ergebnisse
Der folgende Abschnitt enthält Benchmarking-Ergebnisse für DLRM DCN v2 und MaxDiffusion auf v6e.
DLRM DCN v2
Das DLRM DCN v2-Trainingsskript wurde in verschiedenen Skalen ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.
v6e-64 | v6e-128 | v6e-256 | |
Trainingsschritte | 7000 | 7000 | 7000 |
Globale Batchgröße | 131.072 | 262144 | 524.288 |
Durchsatz (Beispiele/Sek.) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Wir haben das Trainingsskript für MaxDiffusion auf einem v6e-4, einem v6e-16 und einem 2xv6e-16 ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.
v6e-4 | v6e-16 | Zwei v6e-16 | |
Trainingsschritte | 0.069 | 0,073 | 0,13 |
Globale Batchgröße | 8 | 32 | 64 |
Durchsatz (Beispiele/Sek.) | 115,9 | 438,4 | 492,3 |
Zeitplan für die Datenerhebung
Trillium (Version 6e) enthält eine neue Funktion namens „collection scheduling“ (Sammlungsplanung). Mit dieser Funktion können Sie mehrere TPU-Slices verwalten, auf denen eine -Inferenzarbeitslast mit einem einzelnen Host sowohl in GKE als auch in der Cloud TPU API ausgeführt wird. Wenn Sie diese Scheiben in einer Sammlung gruppieren, lässt sich die Anzahl der Repliken ganz einfach an die Nachfrage anpassen. Softwareupdates werden sorgfältig gesteuert, damit immer ein Teil der Chunks innerhalb der Sammlung für den Umgang mit eingehenden Zugriffen verfügbar ist.
Weitere Informationen zur Verwendung des Sammlungsplanungstools mit GKE finden Sie in der GKE-Dokumentation.
Die Funktion zum Planen der Datenerhebung gilt nur für Version 6e.
Datenerhebung mit der Cloud TPU API planen
Eine Sammlung mit einem einzelnen Host in der Cloud TPU API ist eine Ressourcenwarteschlange, für die ein spezielles Flag (--workload-type = availability-optimized
) festgelegt ist, um der zugrunde liegenden Infrastruktur anzugeben, dass sie für die Bereitstellung von Arbeitslasten verwendet werden soll.
Mit dem folgenden Befehl wird eine Sammlung mit einem einzelnen Host mithilfe der Cloud TPU API bereitgestellt:
gcloud alpha compute tpus queued-resources create my-collection \ --project=$PROJECT_ID \ --zone=${ZONE} \ --accelerator-type $ACCELERATOR_TYPE \ --node-count ${NODE_COUNT} \ --workload-type=availability-optimized
Überwachen und profilieren
Cloud TPU v6e unterstützt Monitoring und Profiling mit denselben Methoden wie frühere Cloud TPU-Generationen. Weitere Informationen zum Monitoring finden Sie unter TPU-VMs überwachen.