Modell mit TPU v5e trainieren

Mit einem kleineren Footprint von 256 Chips pro Pod ist TPU v5e für das Training, die Feinabstimmung und die Bereitstellung von Transformer-, Text-zu-Bild- und CNN-Modellen (Convolutional Neural Network) optimiert. Weitere Informationen zur Verwendung von Cloud TPU v5e für die Bereitstellung finden Sie unter Inferenz mit v5e.

Weitere Informationen zu Cloud TPU v5e-Hardware und -Konfigurationen finden Sie unter TPU v5e.

Jetzt starten

In den folgenden Abschnitten wird beschrieben, wie Sie mit der Verwendung von TPU v5e beginnen.

Anfragekontingent

Sie benötigen Kontingent, um TPU v5e für das Training zu verwenden. Es gibt verschiedene Kontingenttypen für On-Demand-TPUs, reservierte TPUs und TPU-Spot-VMs. Wenn Sie Ihre TPU v5e für Inferenz verwenden, sind separate Kontingente erforderlich. Weitere Informationen zu Kontingenten finden Sie unter Kontingente. Wenn Sie ein TPU v5e-Kontingent anfordern möchten, wenden Sie sich an den Cloud-Vertrieb.

Google Cloud -Konto und -Projekt erstellen

Sie benötigen ein Google Cloud -Konto und ein Projekt, um Cloud TPU zu verwenden. Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Cloud TPU erstellen

Als Best Practice gilt, Cloud TPU v5es als in die Warteschlange gestellte Ressourcen mit dem Befehl queued-resource create bereitzustellen. Weitere Informationen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.

Sie können auch die Create Node API (gcloud compute tpus tpu-vm create) verwenden, um Cloud TPU v5e-Knoten bereitzustellen. Weitere Informationen finden Sie unter TPU-Ressourcen verwalten.

Weitere Informationen zu den verfügbaren v5e-Konfigurationen für das Training finden Sie unter Cloud TPU v5e-Typen für das Training.

Framework einrichten

In diesem Abschnitt wird die allgemeine Einrichtung für das benutzerdefinierte Modelltraining mit JAX oder PyTorch mit TPU v5e beschrieben.

Eine Anleitung zur Einrichtung der Inferenz finden Sie unter Einführung in die Inferenz in Version 5e.

Definieren Sie einige Umgebungsvariablen:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

Einrichtung für JAX

Wenn Sie Slice-Formen mit mehr als 8 Chips haben, sind mehrere VMs in einem Slice vorhanden. In diesem Fall müssen Sie das Flag --worker=all verwenden, um die Installation in einem einzigen Schritt auf allen TPU-VMs auszuführen, ohne sich über SSH auf jeder einzelnen anzumelden:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Beschreibung der Befehls-Flags

Variable Beschreibung
TPU_NAME Die vom Nutzer zugewiesene Text-ID der TPU, die erstellt wird, wenn die in die Warteschlange gestellte Ressourcenanfrage zugewiesen wird.
PROJECT_ID Google Cloud Projektname Vorhandenes Projekt verwenden oder neues Projekt erstellen unter Google Cloud -Projekt einrichten
ZONE Eine Liste der unterstützten Zonen finden Sie im Dokument TPU-Regionen und -Zonen.
Worker Die TPU-VM, die Zugriff auf die zugrunde liegenden TPUs hat.

Mit dem folgenden Befehl können Sie die Anzahl der Geräte prüfen. Die hier gezeigten Ausgaben wurden mit einem v5litepod-16-Slice erstellt. Mit diesem Code wird getestet, ob alles korrekt installiert ist. Dazu wird geprüft, ob JAX die Cloud TPU-TensorCores sieht und grundlegende Vorgänge ausführen kann:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

Die Ausgabe sollte in etwa so aussehen:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

Die Ausgabe sollte in etwa so aussehen:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

In diesem Dokument finden Sie JAX-Anleitungen, mit denen Sie mit dem Training von v5e-Modellen mit JAX beginnen können.

Einrichtung für PyTorch

Beachten Sie, dass v5e nur die PJRT-Laufzeit unterstützt. In PyTorch 2.1+ wird PJRT als Standardlaufzeit für alle TPU-Versionen verwendet.

In diesem Abschnitt wird beschrieben, wie Sie PJRT auf v5e mit PyTorch/XLA mit Befehlen für alle Worker verwenden.

Abhängigkeiten installieren

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Ersetzen Sie PYTORCH_VERSION durch die PyTorch-Version, die Sie verwenden möchten. PYTORCH_VERSION wird verwendet, um dieselbe Version für PyTorch/XLA anzugeben. Version 2.6.0 wird empfohlen.

Weitere Informationen zu Versionen von PyTorch und PyTorch/XLA finden Sie unter PyTorch – Erste Schritte und PyTorch/XLA-Releases.

Weitere Informationen zur Installation von PyTorch/XLA finden Sie unter PyTorch/XLA installieren.

Wenn beim Installieren der Wheels für torch, torch_xla oder torchvision ein Fehler wie pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222 auftritt, führen Sie ein Downgrade der Version mit diesem Befehl durch:

pip3 install setuptools==62.1.0

Skript mit PJRT ausführen

unset LD_PRELOAD

Das folgende Beispiel zeigt, wie mit einem Python-Script eine Berechnung auf einer v5e-VM durchgeführt wird:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

Mit den PyTorch-Tutorials in diesem Dokument können Sie mit dem Training von v5e-Modellen mit PyTorch beginnen.

Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung. So löschen Sie eine Ressource in der Warteschlange:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Mit diesen beiden Schritten können Sie auch in die Warteschlange gestellte Ressourcenanfragen entfernen, die sich im Status FAILED befinden.

JAX/FLAX-Beispiele

In den folgenden Abschnitten finden Sie Beispiele für das Trainieren von JAX- und FLAX-Modellen auf TPU v5e.

ImageNet auf v5e trainieren

In dieser Anleitung wird beschrieben, wie Sie ImageNet auf v5e mit gefälschten Eingabedaten trainieren. Wenn Sie echte Daten verwenden möchten, lesen Sie die README-Datei auf GitHub.

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
    TPU_NAME Der Name der TPU.
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.
    SERVICE_ACCOUNT Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“.

    Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage.

  2. TPU-Ressource erstellen:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre in die Warteschlange gestellte Ressource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die QueuedResource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Installieren Sie die neueste Version von JAX und jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Klonen Sie das ImageNet-Modell und installieren Sie die entsprechenden Anforderungen:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. Um gefälschte Daten zu generieren, benötigt das Modell Informationen zu den Dimensionen des Datasets. Diese Informationen können aus den Metadaten des ImageNet-Datasets abgerufen werden:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

Modell trainieren

Sobald alle vorherigen Schritte abgeschlossen sind, können Sie das Modell trainieren.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hugging Face FLAX-Modelle

Hugging Face-Modelle, die in FLAX implementiert sind, funktionieren sofort auf Cloud TPU v5e. In diesem Abschnitt finden Sie Anleitungen zum Ausführen beliebter Modelle.

ViT auf Imagenette trainieren

In dieser Anleitung erfahren Sie, wie Sie das Vision Transformer-Modell (ViT) von HuggingFace mit dem Fast AI-Dataset Imagenette auf Cloud TPU v5e trainieren.

Das ViT-Modell war das erste, mit dem ein Transformer-Encoder erfolgreich auf ImageNet trainiert wurde und das im Vergleich zu Convolutional Networks hervorragende Ergebnisse lieferte. Weitere Informationen finden Sie in den folgenden Ressourcen:

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
    TPU_NAME Der Name der TPU.
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.
    SERVICE_ACCOUNT Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“.

    Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage.

  2. TPU-Ressource erstellen:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre in die Warteschlange gestellte Ressource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Installieren Sie JAX und die zugehörige Bibliothek:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Laden Sie das Hugging Face-Repository herunter und installieren Sie die Anforderungen:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. Laden Sie das Imagenette-Dataset herunter:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Modell trainieren

Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark-Ergebnisse für ViT

Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsätze für verschiedene Beschleunigertypen aufgeführt.

Beschleunigertyp v5litepod-4 v5litepod-16 v5litepod-64
Epoche 3 3 3
Globale Batchgröße 32 128 512
Durchsatz (Beispiele/s) 263,40 429,34 470,71

Diffusion auf Pokémon trainieren

In dieser Anleitung wird gezeigt, wie Sie das Stable Diffusion-Modell von HuggingFace mit dem Pokémon-Dataset auf Cloud TPU v5e trainieren.

Das Stable Diffusion-Modell ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert. Weitere Informationen finden Sie in den folgenden Ressourcen:

Einrichten

  1. Legen Sie eine Umgebungsvariable für den Namen Ihres Speicher-Buckets fest:

    export GCS_BUCKET_NAME=your_bucket_name
  2. Speicher-Bucket für die Modellausgabe einrichten:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
    TPU_NAME Der Name der TPU.
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.
    SERVICE_ACCOUNT Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“.

    Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage.

  4. TPU-Ressource erstellen:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  5. Installieren Sie JAX und die zugehörige Bibliothek.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. Laden Sie das HuggingFace-Repository herunter und installieren Sie die Anforderungen.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

Modell trainieren

Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

Bereinigen

Löschen Sie am Ende der Sitzung Ihre TPU, die in die Warteschlange gestellte Ressource und den Cloud Storage-Bucket.

  1. Löschen Sie Ihre TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. Löschen Sie die in die Warteschlange gestellte Ressource:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. Löschen Sie den Cloud Storage-Bucket:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

Benchmark-Ergebnisse für Diffusion

Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsätze aufgeführt.

Beschleunigertyp v5litepod-4 v5litepod-16 v5litepod-64
Trainingsschritt 1500 1500 1500
Globale Batchgröße 32 64 128
Durchsatz (Beispiele/s) 36,53 43,71 49.36

PyTorch/XLA

In den folgenden Abschnitten werden Beispiele für das Trainieren von PyTorch/XLA-Modellen auf TPU v5e beschrieben.

ResNet mit der PJRT-Laufzeit trainieren

PyTorch/XLA wird ab PyTorch 2.0 von XRT zu PjRt migriert. Hier finden Sie die aktualisierte Anleitung zum Einrichten von v5e für PyTorch/XLA-Trainingsarbeitslasten.

Einrichten
  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
    TPU_NAME Der Name der TPU.
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.
    SERVICE_ACCOUNT Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“.

    Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage.

  2. TPU-Ressource erstellen:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die QueuedResource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Torch/XLA-spezifische Abhängigkeiten installieren

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    Ersetzen Sie PYTORCH_VERSION durch die PyTorch-Version, die Sie verwenden möchten. PYTORCH_VERSION wird verwendet, um dieselbe Version für PyTorch/XLA anzugeben. Version 2.6.0 wird empfohlen.

    Weitere Informationen zu Versionen von PyTorch und PyTorch/XLA finden Sie unter PyTorch – Erste Schritte und PyTorch/XLA-Releases.

    Weitere Informationen zur Installation von PyTorch/XLA finden Sie unter PyTorch/XLA installieren.

ResNet-Modell trainieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
Benchmark-Ergebnis

In der folgenden Tabelle sehen Sie die Benchmark-Durchsätze.

Beschleunigertyp Durchsatz (Beispiele/Sekunde)
v5litepod-4 4240 ex/s
v5litepod-16 10.810 ex/s
v5litepod-64 46.154 ex/s

ViT auf v5e trainieren

In dieser Anleitung wird beschrieben, wie Sie VIT auf v5e mit dem HuggingFace-Repository für PyTorch/XLA auf dem cifar10-Dataset ausführen.

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Beschreibungen von Umgebungsvariablen

    Variable Beschreibung
    PROJECT_ID Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues.
    TPU_NAME Der Name der TPU.
    ZONE Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen.
    ACCELERATOR_TYPE Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    RUNTIME_VERSION Die Softwareversion der Cloud TPU.
    SERVICE_ACCOUNT Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“.

    Beispiel: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage.

  2. TPU-Ressource erstellen:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre QueuedResource im Status ACTIVE befindet:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. PyTorch/XLA-Abhängigkeiten installieren

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    Ersetzen Sie PYTORCH_VERSION durch die PyTorch-Version, die Sie verwenden möchten. PYTORCH_VERSION wird verwendet, um dieselbe Version für PyTorch/XLA anzugeben. Version 2.6.0 wird empfohlen.

    Weitere Informationen zu Versionen von PyTorch und PyTorch/XLA finden Sie unter PyTorch – Erste Schritte und PyTorch/XLA-Releases.

    Weitere Informationen zur Installation von PyTorch/XLA finden Sie unter PyTorch/XLA installieren.

  4. Laden Sie das HuggingFace-Repository herunter und installieren Sie die Anforderungen.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

Modell trainieren

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark-Ergebnis

In der folgenden Tabelle sind die Benchmark-Durchsätze für verschiedene Beschleunigertypen aufgeführt.

v5litepod-4 v5litepod-16 v5litepod-64
Epoche 3 3 3
Globale Batchgröße 32 128 512
Durchsatz (Beispiele/s) 201 657 2.844