Modell mit TPU v6e trainieren

In diesem Dokument wird beschrieben, wie Sie Modelle auf Cloud TPU v6e (auch Trillium genannt) trainieren. Dabei werden die Einrichtung der Umgebung, die Leistungsoptimierung und praktische Trainingsbeispiele mit JAX und PyTorch/XLA behandelt.

TPU v6e, auch Trillium genannt, ist die 6. Generation von TPUs von Google. Auf allen technischen Oberflächen wie der API und in Logs sowie in diesem Dokument wird Trillium als v6e bezeichnet. Mit 256 Chips pro Pod weist die Architektur von TPU v6e viele Ähnlichkeiten mit v5e auf. TPU v6e ist für das Training, die Feinabstimmung und die Bereitstellung von Transformer-, Text-zu-Bild- und Convolutional Neural Networks (CNNs) optimiert. Weitere Informationen zur Systemarchitektur und zu den Konfigurationen von TPU v6e finden Sie unter TPU v6e.

Informationen zum Ausführen von Inferenz auf Cloud TPU v6e finden Sie in den folgenden Anleitungen:

Hinweise

Bevor Sie beginnen, müssen Sie Folgendes tun:

  • Google Cloud Konto und Projekt mit aktivierter Abrechnung erstellen
  • Google Cloud CLI-Alphakomponenten installieren
  • Cloud TPU API aktivieren
  • Cloud TPU-Dienst-Agent erstellen
  • Cloud TPU-Dienstkonto erstellen und Berechtigungen erteilen

Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Kontingent und Berechtigungen prüfen

Prüfen Sie, ob Ihr Projekt die folgenden Kontingente hat:

Wenn Sie GKE mit XPK verwenden, benötigen Sie zusätzliche Berechtigungen in der Google Cloud Konsole. Weitere Informationen finden Sie unter Erforderliche Berechtigungen in derGoogle Cloud -Konsole .

TPUs bereitstellen

Sie können TPU v6e mit den folgenden Methoden bereitstellen und verwalten:

  • GKE: Mit GKE können Sie TPUs als Pool von Beschleunigern für Ihre containerisierten ML-Arbeitslasten bereitstellen und verwalten. Weitere Informationen finden Sie unter TPUs in GKE.
  • GKE und XPK: XPK ist ein Befehlszeilentool, das die Clustererstellung und die Ausführung von Arbeitslasten in GKE vereinfacht. Sie wurde für ML-Experten entwickelt, die TPUs bereitstellen und Trainingsjobs ausführen möchten, ohne über umfassende Kubernetes-Kenntnisse verfügen zu müssen. Weitere Informationen finden Sie im XPK-GitHub-Repository.
  • Cloud TPU-Ressourcen in der Warteschlange: Mit Ressourcen in der Warteschlange können Sie TPU-Kapazität anfordern, die bereitgestellt wird, sobald sie verfügbar ist. Sie eignet sich ideal für Batchjobs und fehlertolerante Arbeitslasten, die in einer Warteschlange warten können. Sie können ein Zeitfenster für Ihre Anfrage angeben. Weitere Informationen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.

v6e-Cloud TPUs mit GKE und XPK bereitstellen

Wenn Sie GKE-Befehle mit v6e verwenden, können Sie Cloud TPUs mit Kubernetes-Befehlen oder XPK bereitstellen und Modelle trainieren oder bereitstellen. Unter Cloud TPUs in GKE planen erfahren Sie, wie Sie Ihre Cloud TPU-Konfigurationen in GKE-Clustern planen. In den folgenden Abschnitten finden Sie Befehle zum Erstellen eines XPK-Clusters mit Unterstützung für einzelne NICs und mehrere NICs.

XPK-Cluster mit Unterstützung für eine einzelne NIC erstellen

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Segmente, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Die Anzahl der zu erstellenden Slices.
NETWORK_NAME Der Name eines zu verwendenden sekundären Netzwerks.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

XPK-Cluster mit Unterstützung für mehrere NICs erstellen

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Beschreibung der Befehls-Flags

Variable Beschreibung
CLUSTER_NAME Der vom Nutzer zugewiesene Name für den XPK-Cluster.
PROJECT_ID Google Cloud Projektname. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. Weitere Informationen finden Sie unter Google Cloud -Projekt einrichten.
ZONE Informationen zu den unterstützten Zonen finden Sie im Dokument Cloud TPU-Regionen und -Zonen.
TPU_TYPE Weitere Informationen finden Sie unter Beschleunigertypen.
NUM_SLICES Die Anzahl der Segmente, die Sie erstellen möchten
CLUSTER_ARGUMENTS Das zu verwendende Netzwerk und Subnetzwerk.

Beispiel: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Zusätzliches Knotennetzwerk, das verwendet werden soll.

Beispiel: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Die Anzahl der zu erstellenden Slices (nur für Multislice erforderlich).
NETWORK_NAME Der Name eines zu verwendenden sekundären Netzwerks.
NETWORK_FW_NAME Der Name einer sekundären Netzwerk-Firewall, die verwendet werden soll.

JAX oder PyTorch einrichten

In den folgenden Ressourcen wird beschrieben, wie Sie JAX oder PyTorch auf Ihrer Cloud TPU einrichten, je nachdem, welche Bereitstellungs- und Verwaltungsmethode Sie verwenden:

Informationen zum Einrichten und Ausführen von XPK mit MaxText finden Sie unter MaxText mit XPK im großen Maßstab ausführen .

Netzwerkleistung optimieren

In diesem Abschnitt wird beschrieben, wie Sie die Netzwerkleistung optimieren, indem Sie die maximale Übertragungseinheit (MTU) konfigurieren, mehrere NICs für Multislice-Umgebungen verwenden und die TCP-Einstellungen verbessern.

MTU konfigurieren

Für die beste Netzwerkleistung sollten Sie ein Netzwerk mit einer MTU (maximale Übertragungseinheit) von 8.896 verwenden.

Standardmäßig bietet eine Virtual Private Cloud (VPC) nur eine MTU von 1.460 Byte, was zu einer suboptimalen Netzwerkleistung führt. Sie können die MTU eines VPC-Netzwerk auf einen beliebigen Wert zwischen 1.300 Byte und 8.896 Byte (einschließlich) festlegen. Gängige benutzerdefinierte MTU-Größen sind 1.500 Byte (Standard-Ethernet) oder 8.896 Byte (das maximal mögliche). Weitere Informationen finden Sie unter Gültige MTU-VPC-Netzwerk-Netzwerke.

Weitere Informationen zum Ändern der MTU-Einstellung für ein vorhandenes oder Standardnetzwerk finden Sie unter MTU-Einstellung eines VPC-Netzwerks ändern.

Im folgenden Beispiel wird ein Netzwerk mit einer MTU von 8.896 und einer entsprechenden Firewallregel erstellt, die TCP-, ICMP- und UDP-Traffic innerhalb des Netzwerks zulässt.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Ersetzen Sie your-resource-name durch einen Basisnamen für das Netzwerk und die Firewall.

Multi-NIC-Option für Multislice verwenden

Wenn Sie eine Multislice-Umgebung verwenden, legen Sie die folgenden Umgebungsvariablen fest, die für ein sekundäres Subnetz erforderlich sind:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Verwenden Sie die folgenden Befehle, um benutzerdefiniertes IP-Routing für das Netzwerk und das Subnetz zu erstellen.

  1. Erstellen Sie das sekundäre Netzwerk.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Erstellen Sie ein Subnetzwerk für das sekundäre Netzwerk.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Erstellen Sie eine Firewallregel, die Traffic innerhalb des neuen Subnetzwerks zulässt.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Erstellen Sie einen Cloud Router für das sekundäre Netzwerk.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Erstellen Sie eine NAT-Konfiguration für den Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Nachdem Sie einen Multi-Network-Slice erstellt haben, können Sie prüfen, ob beide Netzwerkkarten (NICs) verwendet werden. Dazu richten Sie einen XPK-Cluster ein und fügen dem Befehl zum Erstellen von XPK-Arbeitslasten das Flag --command ifconfig hinzu.

  1. Verwenden Sie den folgenden workload create-Befehl, um die Ausgabe des ifconfig-Befehls in den Google Cloud -Konsolenlogs anzuzeigen und zu prüfen, ob für eth0 und eth1 die MTU auf 8.896 festgelegt ist.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Wenn Sie Debug-Logs aktivieren oder Vertex AI TensorBoard verwenden möchten, fügen Sie dem Befehl die folgenden optionalen Argumente hinzu:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Prüfen Sie,ob für eth0 und eth1 die MTU auf 8.896 festgelegt ist. Sehen Sie dazu in den Google Cloud -Konsolenlogs nach.

TCP-Einstellungen verbessern

Wenn Sie Ihre Cloud TPUs mit in die Warteschlange gestellten Ressourcen bereitgestellt haben, können Sie die Netzwerkleistung verbessern, indem Sie die TCP-Empfangspufferlimits erhöhen. Führen Sie dazu den folgenden Befehl aus.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Leistung der Speicherzuweisung optimieren

Die tcmalloc-Bibliothek wird standardmäßig auf Cloud TPU-VMs verwendet, um die Leistung von Modellen mit umfangreichen, häufigen Speicherzuweisungen zu verbessern. Dies wird über die Umgebungsvariable LD_PRELOAD konfiguriert.

Bei einigen Arbeitslasten (z. B. DLRM mit sehr großen Zuweisungen für Einbettungstabellen) kann tcmalloc jedoch zu einer Verlangsamung führen. In solchen Fällen können Sie zur Standardfunktion malloc zurückkehren, indem Sie die Variable LD_PRELOAD in Ihrer Shell-Sitzung vor dem Ausführen des Trainingsskripts aufheben:

unset LD_PRELOAD

SkyPilot verwenden

Sie können Cloud TPU v6e mit SkyPilot verwenden. SkyPilot ist ein Open-Source-Framework, das das Ausführen, Verwalten und Skalieren von KI-Arbeitslasten vereinfacht. Sie können SkyPilot v6e-bezogene Standort- und Preisinformationen hinzufügen. Weitere Informationen finden Sie im SkyPilot-Beispiel für TPU v6e.

Trainingsbeispiele

In den folgenden Abschnitten finden Sie Beispiele für das Training von MaxText-, MaxDiffusion- und PyTorch-Modellen auf Cloud TPU v6e.

Diese Beispiele wurden mit den folgenden Softwareversionen getestet:

  • Python 3.10 oder höher
  • Nightly-Softwareversionen:
    • Nächtlicher JAX-Wert 0.4.32.dev20240912
    • Nächtliche LibTPU-Version 0.1.dev20240912+nightly
  • Stabile Softwareversionen:
    • JAX + JAX-Bibliothek v0.4.37

MaxText und MaxDiffusion auf Cloud TPU v6e trainieren

In den folgenden Abschnitten wird der Trainingszyklus der Modelle MaxText und MaxDiffusion beschrieben.

Im Allgemeinen sind folgende Schritte erforderlich:

  1. Erstellen Sie das Basis-Image für die Arbeitslast.
  2. Führen Sie Ihre Arbeitslast mit XPK aus.
    1. Erstellen Sie den Trainingsbefehl für die Arbeitslast.
    2. Stellen Sie die Arbeitslast bereit.
  3. Arbeitslast verfolgen und Messwerte ansehen
  4. Löschen Sie die XPK-Arbeitslast, wenn sie nicht benötigt wird.
  5. Löschen Sie den XPK-Cluster, wenn er nicht mehr benötigt wird.

Basis-Image erstellen

Installieren Sie MaxText oder MaxDiffusion und erstellen Sie das Docker-Image:

  1. Klonen Sie das gewünschte Repository und wechseln Sie in das Verzeichnis des Repositorys:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Konfigurieren Sie Docker für die Verwendung der Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Erstellen Sie das Docker-Image mit dem folgenden Befehl oder mit einem JAX AI-Image. Weitere Informationen zu JAX AI-Images finden Sie unter JAX AI-Images.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Legen Sie Ihre Projekt-ID in der aktiven gcloud CLI-Konfiguration fest:

    gcloud config set project ${PROJECT_ID}
    
  5. Wenn Sie die Arbeitslast von einem Computer aus starten, auf dem das Image nicht lokal erstellt wurde, laden Sie das Image hoch.

    1. Legen Sie die Umgebungsvariable CLOUD_IMAGE_NAME fest:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Laden Sie das Bild hoch:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Arbeitslast mit XPK ausführen

  1. Legen Sie die folgenden Umgebungsvariablen fest, wenn Sie nicht die von MaxText festgelegten Standardwerte oder MaxDiffusion verwenden:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Modellskript erstellen Dieses Skript wird in einem späteren Schritt als Trainingsbefehl kopiert.

    Führen Sie das Modellskript noch nicht aus.

    MaxText

    MaxText ist ein leistungsstarkes, hochgradig skalierbares Open-Source-LLM, das in reinem Python und JAX geschrieben wurde und auf Google Cloud TPUs und GPUs für Training und Inferenz ausgerichtet ist.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma ist eine Familie von LLMs mit offenen Gewichten, die von Google DeepMind entwickelt wurden und auf der Forschung und Technologie von Gemini basieren.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral ist ein hochmodernes KI-Modell, das von Mistral AI entwickelt wurde und eine spärliche MoE-Architektur (Mixture of Experts) nutzt.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama ist eine Familie von LLMs mit offenen Gewichten, die von Meta entwickelt wurden.

    Ein Beispiel für die Ausführung von Llama3 in PyTorch finden Sie unter torch_xla-Modelle im torchprime-Repository.

    MaxDiffusion

    MaxDiffusion ist eine Sammlung von Referenzimplementierungen verschiedener latenter Diffusionsmodelle, die in reinem Python und JAX geschrieben sind und auf XLA-Geräten wie Cloud TPUs und GPUs ausgeführt werden. Stable Diffusion ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert.

    Sie müssen einen bestimmten Git-Branch installieren, um MaxDiffusion auszuführen. Das folgende Trainingsskript zeigt, wie das geht.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Exportieren Sie die folgenden Variablen:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    CLUSTER_NAME Der Name Ihres XPK-Clusters.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    NUM_SLICES Die Anzahl der TPU-Slices.
    YOUR_MODEL_SCRIPT Das Modellskript, das als Trainingsbefehl ausgeführt werden soll.
  4. Führen Sie das Modell mit dem Skript aus, das Sie im vorherigen Schritt erstellt haben. Sie müssen entweder das Flag --base-docker-image angeben, um das MaxText-Basis-Image zu verwenden, oder das Flag --docker-image und das gewünschte Image.

    Sie können die folgenden optionalen Flags hinzufügen:

    • Sie können das Debug-Logging aktivieren, indem Sie das Flag --enable-debug-logs einfügen. Weitere Informationen finden Sie unter JAX auf MaxText debuggen.
    • Sie können einen Vertex AI-Test erstellen, um Daten in Vertex AI TensorBoard hochzuladen. Dazu müssen Sie das Flag --use-vertex-tensorboard einfügen. Weitere Informationen finden Sie unter JAX auf MaxText mit Vertex AI überwachen.
    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    Die Ausgabe enthält einen Link, über den Sie Ihre Arbeitslast verfolgen können. Öffnen Sie den Link und klicken Sie auf den Tab Logs, um Ihre Arbeitslast in Echtzeit zu verfolgen.

JAX in MaxText debuggen

Verwenden Sie zusätzliche XPK-Befehle, um zu ermitteln, warum der Cluster oder die Arbeitslast nicht ausgeführt wird:

  • XPK-Arbeitslastliste
  • XPK-Prüftool
  • Aktivieren Sie das ausführliche Logging in Ihren Arbeitslastlogs mit dem Flag --enable-debug-logs, wenn Sie die XPK-Arbeitslast erstellen.

JAX auf MaxText mit Vertex AI überwachen

Damit Sie TensorBoard verwenden können, muss Ihrem Google Cloud -Nutzerkonto die Rolle aiplatform.user zugewiesen sein. Führen Sie den folgenden Befehl aus, um diese Rolle zuzuweisen:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Skalar- und Profiler-Daten über das von Vertex AI verwaltete TensorBoard ansehen.

  1. Erhöhen Sie die Resource Management (CRUD)-Anfragen für die Zone, die Sie verwenden, von 600 auf 5.000. Bei kleinen Arbeitslasten mit weniger als 16 VMs ist das möglicherweise kein Problem.

  2. Installieren Sie Abhängigkeiten wie cloud-accelerator-diagnostics für Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Erstellen Sie Ihren XPK-Cluster mit dem Flag --create-vertex-tensorboard, wie in Vertex AI TensorBoard erstellen beschrieben. Sie können diesen Befehl auch für vorhandene Cluster ausführen.

  4. Erstellen Sie Ihren Vertex AI-Test, wenn Sie Ihren XPK-Arbeitslast mit dem Flag --use-vertex-tensorboard und dem optionalen Flag --experiment-name ausführen. Eine vollständige Liste der Schritte finden Sie unter Vertex AI-Test zum Hochladen von Daten in Vertex AI TensorBoard erstellen.

Die Logs enthalten einen Link zu einem Vertex AI TensorBoard, ähnlich dem folgenden:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Sie können den Link zu Vertex AI TensorBoard auch in der Google Cloud Console aufrufen. Rufen Sie Vertex AI Experiments in der Google Cloud Konsole auf. Wählen Sie im Drop-down-Menü die gewünschte Region aus.

Das TensorBoard-Verzeichnis wird auch in den Cloud Storage-Bucket geschrieben, den Sie mit ${BASE_OUTPUT_DIR} angegeben haben.

XPK-Arbeitslast löschen

Verwenden Sie den Befehl xpk workload delete, um eine oder mehrere Arbeitslasten basierend auf dem Jobpräfix oder dem Jobstatus zu löschen. Dieser Befehl kann nützlich sein, wenn Sie XPK-Arbeitslasten gesendet haben, die nicht mehr ausgeführt werden müssen, oder wenn Jobs in der Warteschlange hängen bleiben.

XPK-Cluster löschen

Verwenden Sie den Befehl xpk cluster delete, um den Cluster zu löschen:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Benchmark-Ergebnisse für MaxDiffusion

Wir haben das Trainingsskript für MaxDiffusion auf einer v6e-4, einer v6e-16 und zwei v6e-16 ausgeführt. In der folgenden Tabelle sehen Sie die gemessenen Durchsätze.

v6e-4 v6e-16 Zwei v6e-16
Trainingsschritte 0.069 0,073 0,13
Globale Batchgröße 8 32 64
Durchsatz (Beispiele/s) 115.9 438,4 492,3

Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e trainieren

In diesem Abschnitt wird beschrieben, wie Sie Llama-Modelle mit PyTorch/XLA auf Cloud TPU v6e mit dem WikiText-Dataset trainieren.

Zugriff auf Hugging Face und das Llama 3-Modell erhalten

Für dieses Beispiel benötigen Sie ein Hugging Face-Nutzerzugriffstoken. Informationen zum Erstellen von Nutzerzugriffstokens finden Sie in der Hugging Face-Dokumentation zu Nutzerzugriffstokens.

Außerdem benötigen Sie die Berechtigung für den Zugriff auf das Modell „Llama-3-8B“ auf Hugging Face. Wenn Sie Zugriff erhalten möchten, rufen Sie das Meta-Llama-3-8B-Modell auf Hugging Face auf und beantragen Sie den Zugriff.

Cloud TPU-VM erstellen

Erstellen Sie für dieses Beispiel eine Cloud TPU v6e mit 8 Chips.

  1. Richten Sie Umgebungsvariablen ein:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
    TPU_NAME Der Name der TPU.
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.

  2. Cloud TPU-VM erstellen:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Installation

Installieren Sie den pytorch-tpu/transformers-Fork von Hugging Face Transformers und Abhängigkeiten. Dieses Beispiel wurde mit den folgenden Abhängigkeitsversionen getestet:

  • torch: kompatibel mit 2.5.0
  • torch_xla[tpu]: kompatibel mit 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Modellkonfigurationsdateien einrichten

Im Trainingsbefehl im nächsten Abschnitt, Modell ausführen, werden zwei JSON-Konfigurationsdateien verwendet, um Modellparameter und die Konfiguration für Fully Sharded Data Parallel (FSDP) zu definieren. Mit FSDP-Sharding können Sie beim Training eine größere Batchgröße verwenden, da die Modellgewichte auf mehrere TPUs verteilt werden. Beim Training mit kleineren Modellen kann es ausreichen, Datenparallelität zu verwenden und die Gewichte auf jedem Gerät zu replizieren. Weitere Informationen zum Sharding von Tensoren auf Geräten in PyTorch/XLA finden Sie im SPMD-Nutzerhandbuch für PyTorch/XLA.

  1. Erstellen Sie die Konfigurationsdatei für Modellparameter. Im Folgenden finden Sie die Konfiguration der Modellparameter für Llama-3-8B. Die Konfigurationsdatei für andere Modelle finden Sie auf Hugging Face. Ein Beispiel finden Sie in der Llama-2-7B-Konfiguration.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Erstellen Sie die Konfigurationsdatei für FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Weitere Informationen zu FSDP finden Sie unter Fully Sharded Data Parallel using SPMD .

  3. Laden Sie die Konfigurationsdateien mit dem folgenden Befehl auf Ihre Cloud TPU-VMs hoch:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Modell ausführen

Führen Sie mit den Konfigurationsdateien, die Sie im vorherigen Abschnitt erstellt haben, das Skript run_clm.py aus, um das Llama-3-8B-Modell mit dem WikiText-Dataset zu trainieren. Die Ausführung des Trainingsskripts dauert auf einer Cloud TPU v6e-8 etwa 10 Minuten.

  1. Melden Sie sich mit dem folgenden Befehl auf Ihrer Cloud TPU bei Hugging Face an:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Modelltraining ausführen:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Fehlerbehebung bei PyTorch/XLA

Wenn Sie die optionalen Variablen für das Debugging im vorherigen Abschnitt festgelegt haben, wird das Profil für das Modell am Speicherort gespeichert, der durch die Variable PROFILE_LOGDIR angegeben wird. Sie können die xplane.pb-Datei, die an diesem Speicherort gespeichert ist, extrahieren und tensorboard verwenden, um die Profile in Ihrem Browser gemäß der TensorBoard-Anleitung anzusehen.

Wenn PyTorch/XLA nicht wie erwartet funktioniert, finden Sie im Leitfaden zur Fehlerbehebung Vorschläge zum Debuggen, Profilerstellen und Optimieren Ihres Modells.