Cloud TPU v5e-Training

Cloud TPU v5e ist der KI-Beschleuniger der neuesten Generation von Google Cloud. Mit einem geringeren Platzbedarf von 256 Chips pro Pod ist v5e das Produkt mit dem höchsten Wert für das Training, die Feinabstimmung und die Bereitstellung von Transformern, Text-zu-Bild-Modellen und CNNs (Convolutional Neural Networks). Weitere Informationen zur Verwendung von Cloud TPU v5e für das Bereitstellen finden Sie unter Inferenz mit v5e.

Weitere Informationen zur TPU-Hardware und zu den Konfigurationen von Cloud TPU v5e finden Sie unter TPU v5e.

Mehr erfahren

In den folgenden Abschnitten wird beschrieben, wie Sie mit TPU v5e beginnen.

Anfragekontingent

Sie benötigen ein Kontingent, um TPU v5e für das Training zu verwenden. Es gibt verschiedene Kontingenttypen für On-Demand-TPUs, reservierte TPUs und TPU-Spot-VMs. Wenn Sie Ihre TPU v5e für die Inferenz verwenden, sind separate Kontingente erforderlich. Weitere Informationen zu Kontingenten finden Sie unter Kontingente. Wenn Sie ein TPU v5e-Kontingent anfordern möchten, wenden Sie sich an den Cloud-Vertrieb.

Google Cloud-Konto und -Projekt erstellen

Sie benötigen ein Google Cloud-Konto und ein Projekt, um Cloud TPU verwenden zu können. Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Cloud TPU erstellen

Als Best Practice wird empfohlen, Cloud TPU v5-Ressourcen mit dem Befehl queued-resource create als Ressourcen in der Warteschlange zu provisionieren. Weitere Informationen finden Sie unter In der Warteschlange befindliche Ressourcen verwalten.

Sie können auch die Create Node API (gcloud compute tpus tpu-vm create) verwenden, um Cloud TPU v5-Knoten bereitzustellen. Weitere Informationen finden Sie unter TPU-Ressourcen verwalten.

Weitere Informationen zu verfügbaren v5e-Konfigurationen für das Training finden Sie unter Cloud TPU v5e-Typen für das Training.

Framework einrichten

In diesem Abschnitt wird der allgemeine Einrichtungsprozess für das Training benutzerdefinierter Modelle mit JAX oder PyTorch mit TPU v5e beschrieben. TensorFlow wird von den TPU-Laufzeitversionen tpu-vm-tf-2.18.0-pjrt und tpu-vm-tf-2.18.0-pod-pjrt unterstützt.

Eine Anleitung zur Einrichtung von Inferenzen finden Sie unter Einführung in die v5e-Inferenz.

Einrichtung für JAX

Wenn Sie Slice-Formen mit mehr als 8 Chips haben, befinden sich mehrere VMs in einem Slice. In diesem Fall müssen Sie das Flag --worker=all verwenden, um die Installation auf allen TPU-VMs in einem einzigen Schritt auszuführen, ohne sich über SSH einzeln anzumelden:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Beschreibung der Befehls-Flags

Variable Beschreibung
TPU_NAME Die vom Nutzer zugewiesene Text-ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird.
PROJECT_ID Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter Google Cloud-Projekt einrichten.
ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
Worker Die TPU-VM, die Zugriff auf die zugrunde liegenden TPUs hat.

Mit dem folgenden Befehl können Sie die Geräteanzahl prüfen. Die hier gezeigten Ausgaben wurden mit einem v5litepod-16-Stich erstellt. Mit diesem Code wird geprüft, ob alles korrekt installiert ist. Dazu wird überprüft, ob JAX die Cloud TPU-TensorCores erkennt und grundlegende Vorgänge ausführen kann:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

Die Ausgabe sollte in etwa so aussehen:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() gibt die Gesamtzahl der Chips im angegebenen Slice an. jax.local_device_count() gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

Die Ausgabe sollte in etwa so aussehen:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Sehen Sie sich die JAX-Anleitungen in diesem Dokument an, um mit dem Training mit JAX für v5e zu beginnen.

Einrichtung für PyTorch

Hinweis: V5e unterstützt nur die PJRT-Laufzeit und PyTorch 2.1 und höher verwendet PJRT als Standardlaufzeit für alle TPU-Versionen.

In diesem Abschnitt wird beschrieben, wie Sie PJRT unter v5e mit PyTorch/XLA mit Befehlen für alle Worker verwenden.

Abhängigkeiten installieren

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
      pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'

Wenn beim Installieren der Wheels für torch, torch_xla oder torchvision ein Fehler auftritt, z. B. pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, führen Sie mit diesem Befehl ein Downgrade aus:

pip3 install setuptools==62.1.0

Script mit PJRT ausführen

unset LD_PRELOAD

Im folgenden Beispiel wird mit einem Python-Script eine Berechnung auf einer v5e-VM ausgeführt:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU_C_API
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

Dadurch wird eine Ausgabe generiert, die etwa so aussieht:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

Sehen Sie sich die PyTorch-Tutorials in diesem Dokument an, um mit dem Training mit v5e mit PyTorch zu beginnen.

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung. Wenn Sie eine in der Warteschlange befindliche Ressource löschen möchten, löschen Sie in zwei Schritten den Ausschnitt und dann die Ressource in der Warteschlange:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Mit diesen beiden Schritten können Sie auch Ressourcenanfragen aus der Warteschlange entfernen, die sich im Status FAILED befinden.

JAX/FLAX-Beispiele

In den folgenden Abschnitten werden Beispiele zum Trainieren von JAX- und FLAX-Modellen auf einer TPU v5e beschrieben.

ImageNet mit v5e trainieren

In dieser Anleitung wird beschrieben, wie Sie ImageNet mit v5e mithilfe von gefälschten Eingabedaten trainieren. Wenn Sie echte Daten verwenden möchten, lesen Sie die README-Datei auf GitHub.

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald die in der Warteschlange befindliche Ressource den Status ACTIVE hat:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die QueuedResource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Installieren Sie die neueste Version von JAX und jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Klonen Sie das ImageNet-Modell und installieren Sie die entsprechenden Anforderungen:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'
    
  5. Zum Generieren von Fake-Daten benötigt das Modell Informationen zu den Dimensionen des Datensatzes. Diese Informationen können aus den Metadaten des ImageNet-Datasets abgerufen werden:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'
    

Modell trainieren

Sobald Sie alle vorherigen Schritte ausgeführt haben, können Sie das Modell trainieren.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hugging Face FLAX-Modelle

In FLAX implementierte Hugging Face-Modelle funktionieren ohne zusätzliche Anpassungen auf Cloud TPU v5e. In diesem Abschnitt finden Sie eine Anleitung zum Ausführen beliebter Modelle.

ViT auf Imagenette trainieren

In dieser Anleitung erfahren Sie, wie Sie das Vision Transformer-Modell (ViT) von HuggingFace mit dem Fast AI-Dataset Imagenette auf Cloud TPU v5e trainieren.

Das ViT-Modell war das erste, mit dem ein Transformer-Encoder erfolgreich auf ImageNet trainiert wurde. Im Vergleich zu Convolutional Neural Networks erzielte es hervorragende Ergebnisse. Weitere Informationen finden Sie in den folgenden Ressourcen:

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die erwartete Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Installieren Sie JAX und die zugehörige Bibliothek:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Laden Sie das Repository von Hugging Face herunter und installieren Sie die erforderlichen Pakete:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.18.0 && pip install -r examples/flax/vision/requirements.txt'
    
  5. Laden Sie das Imagenette-Dataset herunter:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Modell trainieren

Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmarking-Ergebnisse für ViT

Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsätze für verschiedene Beschleunigertypen aufgeführt.

Beschleunigertyp v5litepod-4 v5litepod-16 v5litepod-64
Epoche 3 3 3
Globale Batchgröße 32 128 512
Durchsatz (Beispiele/Sek.) 263,40 429,34 470,71

Diffusion mit Pokémon trainieren

In dieser Anleitung erfahren Sie, wie Sie das Stable Diffusion-Modell von HuggingFace mit dem Pokémon-Dataset auf Cloud TPU v5e trainieren.

Das Stable Diffusion-Modell ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert. Weitere Informationen finden Sie in den folgenden Ressourcen:

Einrichten

  1. Richten Sie einen Speicher-Bucket für die Modellausgabe ein.

    gcloud storage buckets create gs://your_bucket 
    --project=your_project
    --location=us-west1
    export GCS_BUCKET_NAME=your_bucket

  2. Umgebungsvariablen erstellen

    export GCS_BUCKET_NAME=your_bucket
    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west1-c
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d

    Beschreibung der Befehls-Flags

    Variable Beschreibung
    GCS_BUCKET_NAME Wird in der Google Cloud Console unter „Cloud Storage“ -> „Buckets“ angezeigt
    PROJECT_ID Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter Google Cloud-Projekt einrichten.
    ACCELERATOR_TYPE Informationen zu Ihrer TPU-Version finden Sie auf der Seite TPU-Versionen.
    ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen.
    RUNTIME_VERSION Verwenden Sie „v2-alpha-tpuv5“ für die RUNTIME_VERSION.
    SERVICE_ACCOUNT Das ist die Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“. Beispiel: tpu-service-account@myprojectID.iam.gserviceaccount.com
    TPU_NAME Die vom Nutzer zugewiesene Text-ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird.
    QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage. Informationen zu Ressourcen in der Warteschlange finden Sie im Dokument Ressourcen in der Warteschlange.
    QUOTA_TYPE Kann reserved oder spot sein. Wenn keines dieser Elemente angegeben ist, wird für QUOTA_TYPE standardmäßig on-demand verwendet. Informationen zu den verschiedenen Arten von Kontingenten, die von Cloud TPU unterstützt werden, finden Sie unter Kontingente.
    VALID_UNTIL_DURATION Die Dauer, für die die Anfrage gültig ist. Informationen zu den verschiedenen gültigen Zeiträumen finden Sie unter Ressourcen in der Warteschlange.
  3. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die erwartete Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  4. Installieren Sie JAX und die zugehörige Bibliothek.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  5. Laden Sie das Repository von HuggingFace und die Installationsanforderungen herunter.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install tensorflow==2.18.0 clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
    

Modell trainieren

Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project ${PROJECT_ID} --worker=all --command="
   git clone https://github.com/google/maxdiffusion
   cd maxdiffusion
   git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use.
   pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
   pip3 install -r requirements.txt
   pip3 install .
   export LIBTPU_INIT_ARGS=""
   python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=your_run base_output_directory=gs://${GCS_BUCKET_NAME}/ enable_profiler=False"

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmarking-Ergebnisse für die Verbreitung

Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. Die folgende Tabelle enthält die Durchsätze.

Beschleunigertyp v5litepod-4 v5litepod-16 v5litepod-64
Trainingsschritt 1500 1500 1500
Globale Batchgröße 32 64 128
Durchsatz (Beispiele/Sek.) 36,53 43,71 49.36

GPT-2 mit dem OSCAR-Dataset trainieren

In dieser Anleitung erfahren Sie, wie Sie das GPT2-Modell von HuggingFace mit dem OSCAR-Dataset auf Cloud TPU v5e trainieren.

GPT2 ist ein Transformer-Modell, das ohne menschliches Labeling auf Rohtexten vortrainiert wurde. Es wurde trainiert, das nächste Wort in Sätzen vorherzusagen. Weitere Informationen finden Sie in den folgenden Ressourcen:

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in der Warteschlange befindliche Ressource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die erwartete Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Installieren Sie JAX und die zugehörige Bibliothek.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Laden Sie das Repository von HuggingFace herunter und erfüllen Sie die Installationsanforderungen.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install TensorFlow && pip install -r examples/flax/language-modeling/requirements.txt'
    
  5. Konfigurationen zum Trainieren des Modells herunterladen

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers/examples/flax/language-modeling && gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/v5litepod-preview/jax/gpt . --recursive'
    

Modell trainieren

Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers/examples/flax/language-modeling && TPU_PREMAPPED_BUFFER_SIZE=4294967296 JAX_PLATFORMS=tpu python3 run_clm_flax.py --output_dir=./gpt --model_type=gpt2 --config_name=./gpt --tokenizer_name=./gpt --dataset_name=oscar --dataset_config_name=unshuffled_deduplicated_no --do_train --do_eval --block_size=512 --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --learning_rate=5e-3 --warmup_steps=1000 --adam_beta1=0.9 --adam_beta2=0.98 --weight_decay=0.01 --overwrite_output_dir --num_train_epochs=3 --logging_steps=500 --eval_steps=2500'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark-Ergebnisse für GPT2

Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. Die folgende Tabelle enthält die Durchsätze.

v5litepod-4 v5litepod-16 v5litepod-64
Epoche 3 3 3
Globale Batchgröße 64 64 64
Durchsatz (Beispiele/Sek.) 74,60 72,97 72,62

PyTorch/XLA

In den folgenden Abschnitten werden Beispiele zum Trainieren von PyTorch-/XLA-Modellen auf TPU v5e beschrieben.

ResNet mit der PJRT-Laufzeit trainieren

PyTorch/XLA wird von PyTorch 2.0 und höher von XRT zu PjRt migriert. Hier finden Sie die aktualisierte Anleitung zum Einrichten von v5e für PyTorch/XLA-Trainingsarbeitslasten.

Einrichten
  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=tpu-name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --{QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die QueuedResource im Status ACTIVE befindet:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die erwartete Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. Torch-/XLA-spezifische Abhängigkeiten installieren

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          sudo apt-get update -y
          sudo apt-get install libomp5 -y
          pip3 install mkl mkl-include
          pip3 install tf-nightly tb-nightly tbp-nightly
          pip3 install numpy
          sudo apt-get install libopenblas-dev -y
          pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
          pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
    
ResNet-Modell trainieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU_C_API
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git reset --hard caf5168785c081cd7eb60b49fe4fffeb894c39d9
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
Benchmark-Ergebnis

In der folgenden Tabelle sind die Benchmarkdurchsätze aufgeführt.

Beschleunigertyp Durchsatz (Beispiele/Sekunde)
v5litepod-4 4.240 Ex/s
v5litepod-16 10.810 ex/s
v5litepod-64 46.154 ex/s

GPT2 mit v5e trainieren

In dieser Anleitung erfahren Sie, wie Sie GPT2 in v5e mit dem Repository von HuggingFace auf PyTorch/XLA mit dem Wikitext-Dataset ausführen.

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald die QueuedResource den Status ACTIVE hat:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die erwartete Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

    state: ACTIVE
    
  3. Installieren Sie die PyTorch-/XLA-Abhängigkeiten.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          sudo apt-get -y update
          sudo apt install -y libopenblas-base
          pip3 install torchvision
          pip3 uninstall -y torch
          pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
          pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
    
  4. Laden Sie das Repository von HuggingFace herunter und erfüllen Sie die Installationsanforderungen.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          git clone https://github.com/pytorch/xla.git
          pip install --upgrade accelerate
          git clone https://github.com/huggingface/transformers.git
          cd transformers
          git checkout ebdb185befaa821304d461ed6aa20a17e4dc3aa2
          pip install .
          git log -1
          pip install datasets evaluate scikit-learn
          '
    
  5. Konfigurationen des vortrainierten Modells herunterladen

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/my_config_2.json transformers/examples/pytorch/language-modeling/ --recursive
          gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/fsdp_config.json transformers/examples/pytorch/language-modeling/'
    

Modell trainieren

Trainieren Sie das 2B-Modell mit einer Batchgröße von 16.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU_C_API
      cd transformers/
      export LD_LIBRARY_PATH=/usr/local/lib/
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      python3 examples/pytorch/xla_spawn.py \
         --num_cores=4 \
         examples/pytorch/language-modeling/run_clm.py \
         --num_train_epochs=3 \
         --dataset_name=wikitext \
         --dataset_config_name=wikitext-2-raw-v1 \
         --per_device_train_batch_size=16 \
         --per_device_eval_batch_size=16 \
         --do_train \
         --do_eval \
         --logging_dir=./tensorboard-metrics \
         --cache_dir=./cache_dir \
         --output_dir=/tmp/test-clm \
         --overwrite_output_dir \
         --cache_dir=/tmp \
         --config_name=examples/pytorch/language-modeling/my_config_2.json \
         --tokenizer_name=gpt2 \
         --block_size=1024 \
         --optim=adafactor \
         --adafactor=true \
         --save_strategy=no \
         --logging_strategy=no \
         --fsdp=full_shard \
         --fsdp_config=examples/pytorch/language-modeling/fsdp_config.json'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark-Ergebnis

Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Benchmark-Durchsätze für verschiedene Beschleunigertypen aufgeführt.

v5litepod-4 v5litepod-16 v5litepod-64
Epoche 3 3 3
config 600 Mio. 2 Milliarden 16 B
Globale Batchgröße 64 128 256
Durchsatz (Beispiele/Sek.) 66 77 31

ViT auf v5e trainieren

In dieser Anleitung erfahren Sie, wie Sie VIT auf v5e mit dem Repository von HuggingFace in PyTorch/XLA auf dem CIFAR10-Dataset ausführen.

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=tpu-name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald die QueuedResource den Status ACTIVE hat:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    Wenn sich die erwartete Ressource im Status ACTIVE befindet, sieht die Ausgabe in etwa so aus:

     state: ACTIVE
    
  3. PyTorch/XLA-Abhängigkeiten installieren

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all
       --command='
          sudo apt-get update -y
          sudo apt-get install libomp5 -y
          pip3 install mkl mkl-include
          pip3 install tf-nightly tb-nightly tbp-nightly
          pip3 install numpy
          sudo apt-get install libopenblas-dev -y
          pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
          pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
    
  4. Laden Sie das Repository von HuggingFace herunter und erfüllen Sie die Installationsanforderungen.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

Modell trainieren

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU_C_API
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

TPU und in die Warteschlange gestellte Ressource löschen

Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.

gcloud compute tpus tpu-vm delete ${TPU_NAME}
   --project=${PROJECT_ID}
   --zone=${ZONE}
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
   --project=${PROJECT_ID}
   --zone=${ZONE}
   --quiet

Benchmark-Ergebnis

In der folgenden Tabelle sind die Benchmark-Durchsätze für verschiedene Arten von Beschleunigern aufgeführt.

v5litepod-4 v5litepod-16 v5litepod-64
Epoche 3 3 3
Globale Batchgröße 32 128 512
Durchsatz (Beispiele/Sek.) 201 657 2.844

TensorFlow 2.x

In den folgenden Abschnitten werden Beispiele für das Trainieren von TensorFlow 2.x-Modellen auf TPU v5e beschrieben.

Resnet auf einer TPU v5e mit einem Host trainieren

In dieser Anleitung wird beschrieben, wie Sie ImageNet auf v5litepod-4 oder v5litepod-8 mit einem fiktiven Dataset trainieren. Wenn Sie ein anderes Dataset verwenden möchten, lesen Sie den Hilfeartikel Dataset vorbereiten.

Einrichten

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5litepod-4
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pjrt
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type

    ACCELERATOR_TYPE kann entweder v5litepod-4 oder v5litepod-8 sein.

  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in die Warteschlange gestellte Ressource im Status ACTIVE befindet. Verwenden Sie den folgenden Befehl, um den Status der in der Warteschlange befindlichen Ressource zu prüfen:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  3. Über SSH eine Verbindung zur TPU herstellen

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  4. Einige Umgebungsvariablen festlegen

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
    export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
  5. Wechseln Sie in das Repository-Verzeichnis für Modelle und installieren Sie die erforderlichen Pakete.

    cd ${MODELS_REPO} && git checkout r2.15.0
    pip install -r official/requirements.txt
    

Modell trainieren

Führen Sie das Trainingsskript aus.

python3 official/vision/train.py \
   --tpu=local \
   --experiment=resnet_imagenet \
   --mode=train_and_eval \
   --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
   --model_dir=${MODEL_DIR} \
   --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"

TPU und in die Warteschlange gestellte Ressource löschen

  1. TPU löschen

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet
    
  2. In die Warteschlange gestellte Ressourcenanfrage löschen

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet
    

Resnet auf einem v5e-System mit mehreren Hosts trainieren

In dieser Anleitung wird beschrieben, wie Sie ImageNet mit einem fiktiven Dataset auf v5litepod-16 oder höher trainieren. Wenn Sie ein anderes Dataset verwenden möchten, lesen Sie den Hilfeartikel Dataset vorbereiten.

  1. Erstellen Sie Umgebungsvariablen:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pod-pjrt
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type

    ACCELERATOR_TYPE kann entweder v5litepod-16 oder größer sein.

  2. Erstellen Sie eine TPU-Ressource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --${QUOTA_TYPE}
    

    Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in der Warteschlange befindliche Ressource im Status ACTIVE befindet. Verwenden Sie den folgenden Befehl, um den Status der in der Warteschlange befindlichen Ressource zu prüfen:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  3. Über SSH eine Verbindung zu Ihrer TPU (Worker 0) herstellen

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  4. Einige Umgebungsvariablen festlegen

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export TPU_LOAD_LIBRARY=0
    export TPU_NAME=your_tpu_name
  5. Wechseln Sie in das Repository-Verzeichnis für Modelle und installieren Sie die erforderlichen Pakete.

     cd $MODELS_REPO && git checkout r2.15.0
     pip install -r official/requirements.txt
    

Modell trainieren

Führen Sie das Trainingsskript aus.

python3 official/vision/train.py \
   --tpu=${TPU_NAME} \
   --experiment=resnet_imagenet \
   --mode=train_and_eval \
   --model_dir=${MODEL_DIR} \
   --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*"

TPU und in die Warteschlange gestellte Ressource löschen

  1. TPU löschen

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet
    
  2. In die Warteschlange gestellte Ressourcenanfrage löschen

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet