Trillium (v6e) – Einführung

In dieser Dokumentation, in der TPU API und in den Protokollen wird „v6e“ für Trillium verwendet. „v6e“ steht für die sechste Generation von TPU von Google.

Mit 256 Chips pro Pod hat die v6e-Architektur viele Ähnlichkeiten mit v5e. Dieses System ist für das Training, die Feinabstimmung und die Bereitstellung von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks) optimiert.

Informationen zur Systemarchitektur und zu den Konfigurationen von v6e finden Sie im Dokument v6e.

In diesem Einführungsdokument liegt der Schwerpunkt auf den Prozessen für das Modelltraining und die Bereitstellung mit den Frameworks JAX, PyTorch oder TensorFlow. Mit jedem Framework können Sie TPUs mithilfe von Ressourcen in der Warteschlange oder der Google Kubernetes Engine (GKE) bereitstellen. Die GKE-Einrichtung kann mit XPK- oder GKE-Befehlen erfolgen.

Allgemeines Verfahren zum Trainieren oder Bereitstellen eines Modells mit v6e

  1. Google Cloud Projekt vorbereiten
  2. Sichere Kapazität
  3. TPU-Umgebung einrichten
  4. Cloud TPU-Umgebung bereitstellen
  5. Arbeitslast für Modelltraining oder Inferenz ausführen
  6. Bereinigen

Google Cloud Projekt vorbereiten

  1. Melden Sie sich in Ihrem Google-Konto an. Wenn Sie noch kein Konto haben, melden Sie sich hier für ein neues Konto an.
  2. Wählen Sie in der Google Cloud Console auf der Seite der Projektauswahl ein Cloud-Projekt aus oder erstellen Sie eines.
  3. Aktivieren Sie die Abrechnung für Ihr Google Cloud-Projekt. Für die gesamte Nutzung von Google Cloud ist eine Abrechnung erforderlich.
  4. Installieren Sie die gcloud alpha-Komponenten.
  5. Führen Sie den folgenden Befehl aus, um die neueste Version der gcloud-Komponenten zu installieren.

    gcloud components update
    
  6. Aktivieren Sie die TPU API mit dem folgenden gcloud-Befehl in Cloud Shell. Sie können sie auch über die Google Cloud Console aktivieren.

    gcloud services enable tpu.googleapis.com
    
  7. Berechtigungen mit dem TPU-Dienstkonto für die Compute Engine API aktivieren

    Dienstkonten ermöglichen dem Cloud TPU-Dienst, auf andere Google Cloud-Dienste zuzugreifen. Ein nutzerverwaltetes Dienstkonto ist eine empfohlene Google Cloud-Methode. Folgen Sie diesen Anleitungen, um Rollen zu erstellen und zu gewähren. Folgende Rollen sind erforderlich:

    • TPU-Administrator
    • Storage-Administrator
    • Log-Autor
    • Monitoring-Messwert-Autor

    a. Richten Sie XPK-Berechtigungen mit Ihrem Nutzerkonto für GKE ein: XPK.

  8. Authentifizieren Sie sich mit Ihrem Google-Konto und legen Sie die Standardprojekt-ID und -Zone fest.
    auth login autorisiert gcloud, mit Google-Nutzeranmeldedaten auf Google Cloud zuzugreifen.
    PROJECT_ID ist der Google Cloud Projektname.
    ZONE ist die Zone, in der Sie die TPU erstellen möchten.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Erstellen Sie eine Dienstidentität für die TPU-VM.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Sichere Kapazität

Wenden Sie sich an Ihren Cloud TPU-Supportmitarbeiter, um ein TPU-Kontingent anzufordern und Fragen zur Kapazität zu stellen.

Cloud TPU-Umgebung bereitstellen

v6e-TPUs können mit GKE, mit GKE und XPK (einem Befehlszeilen-Wrapper für GKE) oder als Ressourcen in der Warteschlange bereitgestellt und verwaltet werden.

Vorbereitung

  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPUS_PER_TPU_FAMILY-Kontingent hat. Dieses gibt die maximale Anzahl von Chips an, auf die Sie in IhremGoogle Cloud -Projekt zugreifen können.
  • v6e wurde mit der folgenden Konfiguration getestet:
    • Python 3.10 oder höher
    • Nightly-Softwareversionen:
      • tägliche JAX 0.4.32.dev20240912
      • nightly LibTPU 0.1.dev20240912+nightly
    • Stabile Softwareversionen:
      • JAX + JAX-Bibliothek der Version 0.4.37
  • Prüfen Sie, ob Ihr Projekt ein ausreichendes TPU-Kontingent für Folgendes hat:

    • TPU-VM-Kontingent
    • Kontingent für IP-Adressen
    • Hyperdisk-Balance-Kontingent

  • Nutzerprojektberechtigungen

Umgebungsvariablen

Erstellen Sie in Cloud Shell die folgenden Umgebungsvariablen:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for provisioning Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Beschreibung der Befehls-Flags

Variable Beschreibung
NODE_ID Die vom Nutzer zugewiesene ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter
ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
ACCELERATOR_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Das ist die E-Mail-Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“.

Beispiel: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich).
QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage.
VALID_DURATION Die Dauer, für die die angeforderte Ressource gültig ist.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Optimierungen der Netzwerkleistung

Für die beste Leistung verwenden Sie ein Netzwerk mit 8.896 MTU (maximale Übertragungseinheit).

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Ethernet-Standard) oder 8.896 Byte (maximal möglich). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit 8.896 MTU erstellt.

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME}
 --allow tcp,icmp,udp --project=${PROJECT}

Mehrere NICs verwenden (Option für Multi-Slice)

Die folgenden Umgebungsvariablen sind für ein sekundäres Subnetz erforderlich, wenn Sie eine Multislice-Umgebung verwenden.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Verwenden Sie die folgenden Befehle, um eine benutzerdefinierte IP-Weiterleitung für das Netzwerk und das Subnetz zu erstellen.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT_ID}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"

gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT_ID}" \
  --enable-logging

Nachdem ein Multi-Network Slice erstellt wurde, können Sie prüfen, ob beide NICs verwendet werden. Dazu richten Sie einen XPK-Cluster ein und führen --command ifconfig als Teil der XPK-Arbeitslast aus.

Verwenden Sie den folgenden xpk workload-Befehl, um die Ausgabe des Befehls ifconfig in den Cloud Console-Protokollen anzuzeigen, und prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8896 haben.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8.896 haben. Sie können prüfen, ob Multi-NIC verwendet wird, indem Sie den Befehl „–command ifconfig“ als Teil der XPK-Arbeitslast ausführen. Sehen Sie sich dann die gedruckte Ausgabe dieser xpk-Arbeitslast in den Cloud Console-Protokollen an und prüfen Sie, ob sowohl eth0 als auch eth1 die mtu=8896 haben.

Verbesserte TCP-Einstellungen

Bei TPUs, die über die Benutzeroberfläche für Ressourcen in der Warteschlange erstellt wurden, können Sie den folgenden Befehl ausführen, um die Netzwerkleistung zu verbessern, indem Sie die TCP-Empfangsbufferlimits erhöhen.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "$PROJECT" \
  --zone "$ZONE" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

Bereitstellung mit in die Warteschlange gestellten Ressourcen

Die zugewiesene Kapazität kann mit dem Befehl create für Ressourcen in der Warteschlange bereitgestellt werden.

  1. Erstellen Sie eine Anfrage für eine in die Warteschlange gestellte TPU-Ressource.

    Das --reserved-Flag ist nur für reservierte Ressourcen erforderlich, nicht für On-Demand-Ressourcen.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]
    
      # The following flags are only needed if you are using Multislice.
      --node-count node-count  # Number of slices in a Multislice \
      --node-prefix node-prefix # An optional user-defined node prefix;
       the default is QUEUED_RESOURCE_ID.

    Wenn die Anfrage für die Ressourcenwarteschlange erfolgreich erstellt wurde, hat das Feld „response“ entweder den Status „WAITING_FOR_RESOURCES“ oder „FAILED“. Wenn die in der Warteschlange befindliche Ressourcenanfrage den Status „WAITING_FOR_RESOURCES“ (Warten auf Ressourcen) hat, wurde die Ressource der Warteschlange hinzugefügt und wird bereitgestellt, sobald genügend zugewiesene TPU-Kapazität vorhanden ist. Wenn die Anfrage für die Ressourcenwarteschlange den Status „FAILED“ (Fehlgeschlagen) hat, wird der Grund für den Fehler in der Ausgabe angezeigt. Die anstehende Ressourcenanfrage läuft ab, wenn innerhalb der angegebenen Dauer kein v6e bereitgestellt wird. Der Status ändert sich dann in „FAILED“. Weitere Informationen finden Sie in der öffentlichen Dokumentation zu Ressourcen in der Warteschlange.

    Wenn Ihre in die Warteschlange gestellte Ressourcenanfrage den Status „AKTIV“ hat, können Sie über SSH eine Verbindung zu Ihren TPU-VMs herstellen. Verwenden Sie die Befehle list oder describe, um den Status der in der Warteschlange befindlichen Ressource abzufragen.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Wenn sich die erwartete Ressource im Status „AKTIV“ befindet, sieht die Ausgabe in etwa so aus:

      state:
       state: ACTIVE
    
  2. TPU-VMs verwalten Informationen zu den Optionen zum Verwalten Ihrer TPU-VMs finden Sie unter TPU-VMs verwalten.

  3. Über SSH eine Verbindung zu TPU-VMs herstellen

    Sie können Binärdateien auf jeder TPU-VM in Ihrem TPU-Speicherplatz installieren und Code ausführen. Im Abschnitt VM-Typen erfahren Sie, wie viele VMs Ihr Slice haben wird.

    Wenn Sie die Binärdateien installieren oder Code ausführen möchten, können Sie über SSH mit dem Befehl tpu-vm ssh eine Verbindung zu einer VM herstellen.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Wenn Sie über SSH eine Verbindung zu einer bestimmten VM herstellen möchten, verwenden Sie das Flag --worker, das einem Index mit der Basis 0 folgt:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Wenn Sie Slice-Formen mit mehr als 8 Chips haben, befinden sich mehrere VMs in einem Slice. Verwenden Sie in diesem Fall die Parameter --worker=all und --command in Ihrem gcloud alpha compute tpus tpu-vm ssh-Befehl, um einen Befehl gleichzeitig auf allen VMs auszuführen. Beispiel:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. In die Warteschlange gestellte Ressource löschen

    Lösche am Ende der Sitzung eine in die Warteschlange gestellte Ressource oder entferne in die Warteschlange gestellte Ressourcenanfragen, die den Status „FAILED“ (Fehlgeschlagen) haben. Wenn Sie eine in die Warteschlange gestellte Ressource löschen möchten, löschen Sie in zwei Schritten den Ausschnitt und dann die Anfrage für die in die Warteschlange gestellte Ressource:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

TPUs der Version 6e mit GKE oder XPK bereitstellen

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Kubernetes-Befehle oder XPK verwenden, um TPUs bereitzustellen und Modelle zu trainieren oder bereitzustellen. Unter TPUs in GKE planen erfahren Sie, wie Sie Ihre TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC und mehrere NICs.

Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für eine einzelne NIC

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
   gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
   gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network ${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
   python3 xpk.py cluster create --cluster $CLUSTER_NAME \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=$NUM_SLICES \
   --tpu-type=$TPU_TYPE \
   --zone=$ZONE  \
   --project=$PROJECT \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter
ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Scheiben, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: „--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}“

NUM_SLICES Die Anzahl der zu erstellenden Segmente.
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Befehle zum Erstellen eines XPK-Clusters mit Unterstützung mehrerer NICs

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
   gcloud compute networks create "${NETWORK_NAME_1}" \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=$PROJECT
   gcloud compute networks subnets create "${SUBNET_NAME_1}" \
   --network="${NETWORK_NAME_1}" \
   --range=10.11.0.0/18 \
   --region="${REGION}" \
   --project=$PROJECT
   gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_1}" \
   --allow tcp,icmp,udp \
   --project="${PROJECT}"
  gcloud compute routers create "${ROUTER_NAME}" \
    --project="${PROJECT}" \
    --network="${NETWORK_NAME_1}" \
    --region="${REGION}"
  gcloud compute routers nats create "${NAT_CONFIG}" \
     --router="${ROUTER_NAME}" \
     --region="${REGION}" \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project="${PROJECT}" \
     --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
   gcloud compute networks create "${NETWORK_NAME_2}" \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=$PROJECT
   gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 \
   --region="${REGION}" \
   --project=$PROJECT
   gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" \
   --allow tcp,icmp,udp \
   --project="${PROJECT}"
   gcloud compute routers create "${ROUTER_NAME}" \
     --project="${PROJECT}" \
     --network="${NETWORK_NAME_2}" \
     --region="${REGION}"
   gcloud compute routers nats create "${NAT_CONFIG}" \
     --router="${ROUTER_NAME}" \
     --region="${REGION}" \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project="${PROJECT}" \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE  \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter
ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Scheiben, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: „--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}“

NODE_POOL_ARGUMENTS Zu verwendendes zusätzliches Knotennetzwerk.

Beispiel: „--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}“

NUM_SLICES Die Anzahl der zu erstellenden Schichten (nur für Mehrfachaufnahmen erforderlich).
NETWORK_NAME Der Name eines sekundären Netzwerks, das verwendet werden soll.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

Framework einrichten

In diesem Abschnitt wird die allgemeine Einrichtung für das Training von ML-Modellen mit den Frameworks JAX, PyTorch oder TensorFlow beschrieben. Sie können TPUs mithilfe von Ressourcen in der Warteschlange oder GKE bereitstellen. Die GKE-Einrichtung kann mit XPK- oder Kubernetes-Befehlen erfolgen.

Einrichtung für JAX

In diesem Abschnitt finden Sie Beispiele für die Ausführung von JAX-Arbeitslasten in GKE mit oder ohne XPK sowie für die Verwendung von Ressourcen in der Warteschlange.

JAX mit GKE einrichten

Im folgenden Beispiel wird ein einzelner 2 × 2-Host mit einer Kubernetes-YAML-Datei eingerichtet.

Einzelnes Slice auf einem einzelnen Host

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 4

Einzelnes Slice auf mehreren Hosts

Im folgenden Beispiel wird ein 4 × 4-Knotenpool mit mehreren Hosts mithilfe einer Kubernetes-YAML-Datei eingerichtet.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 16

Multislice auf mehreren Hosts

Im folgenden Beispiel werden zwei 4 × 4-Multihost-Knotenpools mit einer Kubernetes-YAML-Datei eingerichtet.

Als Voraussetzung müssen Sie JobSet v0.2.3 oder höher installieren.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Nach dem erfolgreichen Abschluss sollte im GKE-Log die folgende Meldung angezeigt werden:

Total TPU chips: 32

Weitere Informationen finden Sie in der GKE-Dokumentation unter Multi-Slice-Arbeitslast ausführen.

Aktivieren Sie für eine bessere Leistung hostNetwork.

Mehrere NICs

Damit Sie in GKE mehrere NICs nutzen können, muss das Kubernetes-Pod-Manifest zusätzliche Anmerkungen enthalten. Im Folgenden finden Sie ein Beispielmanifest für eine Arbeitslast mit mehreren NICs ohne TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Wenn Sie exec in den Kubernetes-Pod eingeben, sollte die zusätzliche NIC mit dem folgenden Code angezeigt werden.

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

JAX mit GKE und XPK einrichten

Ein Beispiel finden Sie in der xpk-README-Datei.

Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText ausführen.

JAX mit Ressourcen in der Warteschlange einrichten

Mit gcloud alpha compute tpus tpu-vm ssh können Sie JAX gleichzeitig auf allen TPU-VMs in Ihrem Slice oder Ihren Slices installieren. Fügen Sie für Multislice --node=all hinzu.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Sie können den folgenden Python-Code ausführen, um zu prüfen, wie viele TPU-Kerne in Ihrem Slice verfügbar sind, und um zu testen, ob alles richtig installiert ist. Die hier gezeigten Ausgaben wurden mit einem v6e-16-Speicher erstellt:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Die Ausgabe sieht in etwa so aus:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

Fehlerbehebung bei JAX-Einrichtungen

Als allgemeinen Tipp können wir Ihnen empfehlen, im Manifest Ihrer GKE-Arbeitslast ausführliches Logging zu aktivieren. Reichen Sie die Protokolle dann beim GKE-Support ein.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Fehlermeldungen

no endpoints available for service 'jobset-webhook-service'

Dieser Fehler bedeutet, dass der Jobsatz nicht richtig installiert wurde. Prüfen Sie, ob die Kubernetes-Pods der Bereitstellung „jobset-controller-manager“ ausgeführt werden. Weitere Informationen finden Sie in der Dokumentation zur Fehlerbehebung bei JobSets.

TPU initialization failed: Failed to connect

Die GKE-Knotenversion muss 1.30.4-gke.1348000 oder höher sein. GKE 1.31 wird nicht unterstützt.

Einrichtung für PyTorch

In diesem Abschnitt wird beschrieben, wie Sie PJRT unter v6e mit PyTorch/XLA verwenden. Python 3.10 ist die empfohlene Python-Version.

PyTorch mit GKE und XPK einrichten

Sie können den folgenden Docker-Container mit XPK verwenden, in dem die PyTorch-Abhängigkeiten bereits installiert sind:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Führen Sie den folgenden Befehl aus, um eine XPK-Arbeitslast zu erstellen:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

Mit --base-docker-image wird ein neues Docker-Image erstellt, in das das aktuelle Arbeitsverzeichnis eingebunden ist.

PyTorch mit in die Warteschlange gestellten Ressourcen einrichten

Führen Sie die folgenden Schritte aus, um PyTorch mit Ressourcen in der Warteschlange zu installieren und ein kleines Script auf v6e auszuführen.

Abhängigkeiten über SSH installieren, um auf die VMs zuzugreifen

Fügen Sie für Multislice --node=all hinzu:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Leistung von Modellen mit großen, häufigen Zuweisungen verbessern

Bei Modellen mit großen, häufigen Zuweisungen haben wir festgestellt, dass die Verwendung von tcmalloc die Leistung im Vergleich zur Standardimplementierung von malloc erheblich verbessert. Daher ist tcmalloc die Standard-malloc auf der TPU-VM. Je nach Arbeitslast kann tcmalloc jedoch zu einer Verlangsamung führen, z. B. bei DLRM mit sehr großen Zuweisungen für die Einbettungstabellen. In diesem Fall können Sie versuchen, die folgende Variable zurückzusetzen und stattdessen die Standardeinstellung malloc zu verwenden:

unset LD_PRELOAD

Mit einem Python-Script eine Berechnung auf einer v6e-VM ausführen:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Einrichtung für TensorFlow

Für die öffentliche Vorschau von v6e wird nur die Laufzeitversion „tf-nightly“ unterstützt.

Sie können tpu-runtime mit der mit v6e kompatiblen TensorFlow-Version zurücksetzen, indem Sie die folgenden Befehle ausführen:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Verwenden Sie SSH, um auf worker-0 zuzugreifen:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Installieren Sie TensorFlow auf worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exportieren Sie die Umgebungsvariable TPU_NAME:

export TPU_NAME=v6e-16

Mit dem folgenden Python-Script können Sie prüfen, wie viele TPU-Kerne in Ihrem Slice verfügbar sind, und testen, ob alles richtig installiert ist. Die angezeigten Ausgaben wurden mit einem v6e-16-Speicherplatz erstellt:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

Die Ausgabe sieht in etwa so aus:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e mit SkyPilot

Sie können TPU v6e mit SkyPilot verwenden. Führen Sie die folgenden Schritte aus, um SkyPilot v6e-bezogene Standort-/Preisinformationen hinzuzufügen.

  1. Fügen Sie am Ende von ~/.sky/catalogs/v5/gcp/vms.csv Folgendes hinzu :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Geben Sie die folgenden Ressourcen in einer YAML-Datei an:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Cluster mit TPU v6e starten:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Stellen Sie eine SSH-Verbindung zur TPU v6e her: ssh tpu_v6

Anleitungen für Inferenz

In den folgenden Anleitungen wird gezeigt, wie Sie die Inferenz auf einer TPU v6e ausführen:

Trainingsbeispiele

In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf TPU v6e.

MaxText- und MaxDiffusion-Training auf einer v6e-Cloud TPU-VM

In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.

Im Allgemeinen sind dies die allgemeinen Schritte:

  1. Erstellen Sie das Basis-Image der Arbeitslast.
  2. Führen Sie Ihre Arbeitslast mit XPK aus.
    1. Erstellen Sie den Trainingsbefehl für die Arbeitslast.
    2. Stellen Sie die Arbeitslast bereit.
  3. Arbeitslast verfolgen und Messwerte ansehen
  4. Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
  5. Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis für das Repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Konfigurieren Sie Docker so, dass die Google Cloud CLI verwendet wird:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit JAX Stable Stack. Weitere Informationen zum JAX Stable Stack finden Sie unter Docker-Image mit JAX Stable Stack erstellen.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
    
  4. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Docker-Image mit JAX Stable Stack erstellen

Sie können die Docker-Images „MaxText“ und „MaxDiffusion“ mit dem Basis-Image „JAX Stable Stack“ erstellen.

JAX Stable Stack bietet eine einheitliche Umgebung für MaxText und MaxDiffusion, indem JAX mit Kernpaketen wie orbax, flax und optax sowie einer gut qualifizierten libtpu.so kombiniert wird, die TPU-Programm-Dienstprogramme und andere wichtige Tools antreibt. Diese Bibliotheken werden auf Kompatibilität getestet und bieten eine stabile Grundlage für die Erstellung und Ausführung von MaxText und MaxDiffusion. So werden potenzielle Konflikte aufgrund von inkompatiblen Paketversionen vermieden.

Der stabile JAX-Stack enthält eine vollständig veröffentlichte und qualifizierte libtpu.so, die Kernbibliothek, die die Kompilierung, Ausführung und ICI-Netzwerkkonfiguration von TPU-Programmen steuert. Der libtpu-Release ersetzt den bisher von JAX verwendeten Nightly-Build und sorgt mit Qualifikationstests auf PJRT-Ebene in HLO/StableHLO-IRs für eine konsistente Funktionsweise von XLA-Berechnungen auf TPUs.

Wenn Sie das Docker-Image für MaxText und MaxDiffusion mit dem JAX Stable Stack erstellen möchten, legen Sie beim Ausführen des docker_build_dependency_image.sh-Scripts die Variable MODE auf stable_stack und die Variable BASEIMAGE auf das gewünschte Basis-Image fest.

Im folgenden Beispiel wird us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1 als Basis-Image angegeben:

bash docker_build_dependency_image.sh MODE=stable_stack
BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1

Eine Liste der verfügbaren JAX Stable Stack-Basis-Images finden Sie unter JAX Stable Stack-Images in Artifact Registry.

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die Standardwerte verwenden, die mit MaxText oder MaxDiffusion festgelegt wurden:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Erstellen Sie das Modellskript. Dieses Script wird in einem späteren Schritt als Trainingsbefehl kopiert.

    Führen Sie das Modellskript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarkes, hoch skalierbares Open-Source-LLM, das in reiner Python- und JAX-Programmierung geschrieben wurde und auf Google Cloud TPUs und GPUs für Training und Inferenz ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind entwickelt wurden und auf der Gemini-Forschung und -Technologie basieren.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine sparse MoE-Architektur (Mixture of Experts) verwendet.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reiner Python- und JAX-Programmierung geschrieben wurden und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert.

    Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen, wie im folgenden git checkout-Befehl gezeigt.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Trainingsskript:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
    
        
  3. Führen Sie das Modell mit dem im vorherigen Schritt erstellten Script aus. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basisbild zu verwenden, oder das Flag --docker-image und das gewünschte Bild.

    Optional: Sie können das Debug-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX in MaxText debuggen.

    Optional: Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Fügen Sie dazu das Flag --use-vertex-tensorboard hinzu. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
        --tpu-type=$ACCELERATOR_TYPE \
        --num-slices=$NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command $YOUR-MODEL-SCRIPT

    Exportieren Sie die folgenden Variablen:

    export ClUSTER_NAME=CLUSTER_NAME: Der Name Ihres XPK-Clusters. export ACCELERATOR_TYPEACCELERATOR_TYPE: Die Version und Größe Ihrer TPU. Beispiel: v6e-256 export NUM_SLICES=NUM_SLICES: Die Anzahl der TPU-Scheiben. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: Das Modellscript, das als Trainingsbefehl ausgeführt werden soll.

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können, ähnlich dem folgenden:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Öffnen Sie den Link und klicken Sie auf den Tab Protokolle, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird.

  • XPK-Arbeitslastliste
  • XPK-Inspektor
  • Aktivieren Sie beim Erstellen der XPK-Arbeitslast mit dem Flag --enable-debug-logs ausführliche Protokolle in Ihren Arbeitslastprotokollen.

JAX auf MaxText mit Vertex AI überwachen

Skalar- und Profildaten über das verwaltete TensorBoard von Vertex AI aufrufen

  1. Erhöhen Sie die Anzahl der Resource Management (CRUD)-Anfragen für die von Ihnen verwendete Zone von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.
  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie unter Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch auf vorhandenen Clustern ausführen.

  4. Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihre XPK-Arbeitslast mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen.

Die Protokolle enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich wie hier:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie finden den Link zu Vertex AI TensorBoard auch in der Google Cloud Console. Rufen Sie in der Google Cloud Console Vertex AI-Tests auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslasten löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um einen Cluster zu löschen:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone $ZONE --project $PROJECT_ID

Llama- und PyTorch/XLA-Training auf einer v6e-Cloud TPU-VM

In dieser Anleitung wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf einer TPU v6e mit dem Dataset WikiText trainieren.

Zugriff auf Hugging Face und das Llama 3-Modell erhalten

Sie benötigen ein Nutzerzugriffstoken für Hugging Face, um dieses Tutorial auszuführen. Informationen zum Erstellen und Verwenden von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.

Außerdem benötigen Sie die Berechtigung, auf das Llama 3 8B-Modell auf Hugging Face zuzugreifen. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie Zugriff.

TPU-VM erstellen

Erstellen Sie eine TPU v6e mit 8 Chips, um die Anleitung auszuführen.

  1. Richten Sie Umgebungsvariablen ein:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT=YOUR_PROJECT
    export ZONE=YOUR_ZONE
  2. So erstellen Sie eine TPU-VM:

    gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \
        --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT

Installation

Installieren Sie den pytorch-tpu/transformers-Fork von Hugging Face Transformers und die Abhängigkeiten. Diese Anleitung wurde mit den folgenden Abhängigkeitsversionen getestet, die in diesem Beispiel verwendet werden:

  • torch: kompatibel mit 2.5.0
  • torch_xla[tpu]: kompatibel mit 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Modellkonfigurationen einrichten

Der Trainingsbefehl im nächsten Abschnitt, Modell ausführen, verwendet zwei JSON-Konfigurationsdateien, um Modellparameter und die FSDP-Konfiguration (Fully Sharded Data Parallel) zu definieren. Das FSDP-Sharding wird für die Modellgewichte verwendet, damit sie während des Trainings zu einer größeren Batchgröße passen. Beim Training mit kleineren Modellen reicht es möglicherweise aus, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im PyTorch/XLA SPMD-Nutzerhandbuch.

  1. Erstellen Sie die Konfigurationsdatei für die Modellparameter. Im Folgenden finden Sie die Modellparameterkonfiguration für Llama3-8B. Für andere Modelle finden Sie die Konfiguration auf Hugging Face. Siehe beispielsweise die Llama2-7B-Konfiguration.

    cat > llama-config.json <
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
  2. Erstellen Sie die FSDP-Konfigurationsdatei:

    cat > fsdp-config.json <
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF

    Weitere Informationen zu FSDP finden Sie unter FSDPv2.

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre TPU-VMs hoch:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=$PROJECT \
        --zone $ZONE

Modell ausführen

Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das run_clm.py-Script aus, um das Llama 3 8B-Modell mit dem WikiText-Dataset zu trainieren. Das Training dauert auf einer TPU v6e-8 etwa 10 Minuten.

  1. Melden Sie sich mit dem folgenden Befehl auf Ihrer TPU in Hugging Face an:

    gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
        --zone $ZONE \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Modelltraining ausführen:

    gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
        --zone $ZONE \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debuggen im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell an dem Speicherort gespeichert, der in der Variablen PROFILE_LOGDIR angegeben ist. Sie können die dort gespeicherte xplane.pb-Datei extrahieren und mit tensorboard die Profile in Ihrem Browser anzeigen lassen. Folgen Sie dazu der TensorBoard-Anleitung. Wenn PyTorch/XLA nicht wie erwartet funktioniert, lesen Sie den Leitfaden zur Fehlerbehebung. Dort finden Sie Vorschläge zum Debuggen, Profilieren und Optimieren Ihres Modells.

DLRM DCN v2-Training auf v6e

In dieser Anleitung wird beschrieben, wie Sie das DLRM DCN v2-Modell auf einer TPU v6e trainieren. Sie müssen eine TPU v6e mit 64, 128 oder 256 Chips bereitstellen.

Wenn Sie mehrere Hosts verwenden, setzen Sie tpu-runtime mit der entsprechenden TensorFlow-Version zurück. Führen Sie dazu den folgenden Befehl aus. Wenn Sie die Anwendung auf einem einzelnen Host ausführen, müssen Sie die folgenden beiden Befehle nicht ausführen.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

SSH-Verbindung zu worker-0 herstellen

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

TPU-Namen festlegen

export TPU_NAME=${TPU_NAME}

DLRM v2 ausführen

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Führen Sie script.sh aus.

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Die folgenden Flags sind erforderlich, um Empfehlungslasten (DLRM DCN) auszuführen:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Benchmarking-Ergebnisse

Der folgende Abschnitt enthält Benchmarking-Ergebnisse für DLRM DCN v2 und MaxDiffusion auf v6e.

DLRM DCN v2

Das DLRM DCN v2-Trainingsskript wurde in verschiedenen Skalen ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-64 v6e-128 v6e-256
Trainingsschritte 7000 7000 7000
Globale Batchgröße 131.072 262144 524.288
Durchsatz (Beispiele/Sek.) 2975334 5111808 10066329

MaxDiffusion

Wir haben das Trainingsskript für MaxDiffusion auf einem v6e-4, einem v6e-16 und einem 2xv6e-16 ausgeführt. Die Durchlaufraten finden Sie in der folgenden Tabelle.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0.069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/Sek.) 115,9 438,4 492,3

Zeitplan für die Datenerhebung

Trillium (Version 6e) enthält eine neue Funktion namens „collection scheduling“ (Sammlungsplanung). Mit dieser Funktion können Sie mehrere TPU-Slices verwalten, auf denen eine -Inferenzarbeitslast mit einem einzelnen Host sowohl in GKE als auch in der Cloud TPU API ausgeführt wird. Wenn Sie diese Scheiben in einer Sammlung gruppieren, lässt sich die Anzahl der Repliken ganz einfach an die Nachfrage anpassen. Softwareupdates werden sorgfältig gesteuert, damit immer ein Teil der Chunks innerhalb der Sammlung für den Umgang mit eingehenden Zugriffen verfügbar ist.

Weitere Informationen zur Verwendung des Sammlungsplanungstools mit GKE finden Sie in der GKE-Dokumentation.

Die Funktion zum Planen der Datenerhebung gilt nur für Version 6e.

Datenerhebung mit der Cloud TPU API planen

Eine Sammlung mit einem einzelnen Host in der Cloud TPU API ist eine Ressourcenwarteschlange, für die ein spezielles Flag (--workload-type = availability-optimized) festgelegt ist, um der zugrunde liegenden Infrastruktur anzugeben, dass sie für die Bereitstellung von Arbeitslasten verwendet werden soll.

Mit dem folgenden Befehl wird eine Sammlung mit einem einzelnen Host mithilfe der Cloud TPU API bereitgestellt:

gcloud alpha compute tpus queued-resources create my-collection \
   --project=$PROJECT_ID \
   --zone=${ZONE} \
   --accelerator-type $ACCELERATOR_TYPE \
   --node-count ${NODE_COUNT} \
   --workload-type=availability-optimized

Überwachen und profilieren

Cloud TPU v6e unterstützt Monitoring und Profiling mit denselben Methoden wie frühere Cloud TPU-Generationen. Weitere Informationen zum Monitoring finden Sie unter TPU-VMs überwachen.