Modell mit TPU v5e trainieren
Mit einem kleineren Footprint von 256 Chips pro Pod ist TPU v5e für das Training, die Feinabstimmung und die Bereitstellung von Transformer-, Text-zu-Bild- und CNN-Modellen (Convolutional Neural Network) optimiert. Weitere Informationen zur Verwendung von Cloud TPU v5e für die Bereitstellung finden Sie unter Inferenz mit v5e.
Weitere Informationen zu Cloud TPU v5e-Hardware und -Konfigurationen finden Sie unter TPU v5e.
Jetzt starten
In den folgenden Abschnitten wird beschrieben, wie Sie mit der Verwendung von TPU v5e beginnen.
Anfragekontingent
Sie benötigen Kontingent, um TPU v5e für das Training zu verwenden. Es gibt verschiedene Kontingenttypen für On-Demand-TPUs, reservierte TPUs und TPU-Spot-VMs. Wenn Sie Ihre TPU v5e für Inferenz verwenden, sind separate Kontingente erforderlich. Weitere Informationen zu Kontingenten finden Sie unter Kontingente. Wenn Sie ein TPU v5e-Kontingent anfordern möchten, wenden Sie sich an den Cloud-Vertrieb.
Google Cloud -Konto und -Projekt erstellen
Sie benötigen ein Google Cloud -Konto und ein Projekt, um Cloud TPU zu verwenden. Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.
Cloud TPU erstellen
Als Best Practice gilt, Cloud TPU v5es als in die Warteschlange gestellte Ressourcen mit dem Befehl queued-resource create
bereitzustellen. Weitere Informationen finden Sie unter In die Warteschlange gestellte Ressourcen verwalten.
Sie können auch die Create Node API (gcloud compute tpus tpu-vm create
) verwenden, um Cloud TPU v5e-Knoten bereitzustellen. Weitere Informationen finden Sie unter TPU-Ressourcen verwalten.
Weitere Informationen zu den verfügbaren v5e-Konfigurationen für das Training finden Sie unter Cloud TPU v5e-Typen für das Training.
Framework einrichten
In diesem Abschnitt wird die allgemeine Einrichtung für das benutzerdefinierte Modelltraining mit JAX oder PyTorch mit TPU v5e beschrieben.
Eine Anleitung zur Einrichtung der Inferenz finden Sie unter Einführung in die Inferenz in Version 5e.
Definieren Sie einige Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
Einrichtung für JAX
Wenn Sie Slice-Formen mit mehr als 8 Chips haben, sind mehrere VMs in einem Slice vorhanden. In diesem Fall müssen Sie das Flag --worker=all
verwenden, um die Installation in einem einzigen Schritt auf allen TPU-VMs auszuführen, ohne sich über SSH auf jeder einzelnen anzumelden:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Beschreibung der Befehls-Flags
Variable | Beschreibung |
TPU_NAME | Die vom Nutzer zugewiesene Text-ID der TPU, die erstellt wird, wenn die in die Warteschlange gestellte Ressourcenanfrage zugewiesen wird. |
PROJECT_ID | Google Cloud Projektname Vorhandenes Projekt verwenden oder neues Projekt erstellen unter Google Cloud -Projekt einrichten |
ZONE | Eine Liste der unterstützten Zonen finden Sie im Dokument TPU-Regionen und -Zonen. |
Worker | Die TPU-VM, die Zugriff auf die zugrunde liegenden TPUs hat. |
Mit dem folgenden Befehl können Sie die Anzahl der Geräte prüfen. Die hier gezeigten Ausgaben wurden mit einem v5litepod-16-Slice erstellt. Mit diesem Code wird getestet, ob alles korrekt installiert ist. Dazu wird geprüft, ob JAX die Cloud TPU-TensorCores sieht und grundlegende Vorgänge ausführen kann:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
Die Ausgabe sollte in etwa so aussehen:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
gibt die Gesamtzahl der Chips im angegebenen Slice an.
jax.local_device_count()
gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
Die Ausgabe sollte in etwa so aussehen:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
In diesem Dokument finden Sie JAX-Anleitungen, mit denen Sie mit dem Training von v5e-Modellen mit JAX beginnen können.
Einrichtung für PyTorch
Beachten Sie, dass v5e nur die PJRT-Laufzeit unterstützt. In PyTorch 2.1+ wird PJRT als Standardlaufzeit für alle TPU-Versionen verwendet.
In diesem Abschnitt wird beschrieben, wie Sie PJRT auf v5e mit PyTorch/XLA mit Befehlen für alle Worker verwenden.
Abhängigkeiten installieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Ersetzen Sie PYTORCH_VERSION
durch die PyTorch-Version, die Sie verwenden möchten.
PYTORCH_VERSION
wird verwendet, um dieselbe Version für PyTorch/XLA anzugeben. Version 2.6.0 wird empfohlen.
Weitere Informationen zu Versionen von PyTorch und PyTorch/XLA finden Sie unter PyTorch – Erste Schritte und PyTorch/XLA-Releases.
Weitere Informationen zur Installation von PyTorch/XLA finden Sie unter PyTorch/XLA installieren.
Wenn beim Installieren der Wheels für torch
, torch_xla
oder torchvision
ein Fehler wie pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
auftritt, führen Sie ein Downgrade der Version mit diesem Befehl durch:
pip3 install setuptools==62.1.0
Skript mit PJRT ausführen
unset LD_PRELOAD
Das folgende Beispiel zeigt, wie mit einem Python-Script eine Berechnung auf einer v5e-VM durchgeführt wird:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Dadurch wird eine Ausgabe generiert, die etwa so aussieht:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Mit den PyTorch-Tutorials in diesem Dokument können Sie mit dem Training von v5e-Modellen mit PyTorch beginnen.
Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung. So löschen Sie eine Ressource in der Warteschlange:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Mit diesen beiden Schritten können Sie auch in die Warteschlange gestellte Ressourcenanfragen entfernen, die sich im Status FAILED
befinden.
JAX/FLAX-Beispiele
In den folgenden Abschnitten finden Sie Beispiele für das Trainieren von JAX- und FLAX-Modellen auf TPU v5e.
ImageNet auf v5e trainieren
In dieser Anleitung wird beschrieben, wie Sie ImageNet auf v5e mit gefälschten Eingabedaten trainieren. Wenn Sie echte Daten verwenden möchten, lesen Sie die README-Datei auf GitHub.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. SERVICE_ACCOUNT
Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“. Beispiel:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die QueuedResource im Status
ACTIVE
befindet, sieht die Ausgabe in etwa so aus:state: ACTIVE
Installieren Sie die neueste Version von JAX und jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Klonen Sie das ImageNet-Modell und installieren Sie die entsprechenden Anforderungen:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
Um gefälschte Daten zu generieren, benötigt das Modell Informationen zu den Dimensionen des Datasets. Diese Informationen können aus den Metadaten des ImageNet-Datasets abgerufen werden:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
Modell trainieren
Sobald alle vorherigen Schritte abgeschlossen sind, können Sie das Modell trainieren.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hugging Face FLAX-Modelle
Hugging Face-Modelle, die in FLAX implementiert sind, funktionieren sofort auf Cloud TPU v5e. In diesem Abschnitt finden Sie Anleitungen zum Ausführen beliebter Modelle.
ViT auf Imagenette trainieren
In dieser Anleitung erfahren Sie, wie Sie das Vision Transformer-Modell (ViT) von HuggingFace mit dem Fast AI-Dataset Imagenette auf Cloud TPU v5e trainieren.
Das ViT-Modell war das erste, mit dem ein Transformer-Encoder erfolgreich auf ImageNet trainiert wurde und das im Vergleich zu Convolutional Networks hervorragende Ergebnisse lieferte. Weitere Informationen finden Sie in den folgenden Ressourcen:
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. SERVICE_ACCOUNT
Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“. Beispiel:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet, sieht die Ausgabe in etwa so aus:state: ACTIVE
Installieren Sie JAX und die zugehörige Bibliothek:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Laden Sie das Hugging Face-Repository herunter und installieren Sie die Anforderungen:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
Laden Sie das Imagenette-Dataset herunter:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Modell trainieren
Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark-Ergebnisse für ViT
Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsätze für verschiedene Beschleunigertypen aufgeführt.
Beschleunigertyp | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoche | 3 | 3 | 3 |
Globale Batchgröße | 32 | 128 | 512 |
Durchsatz (Beispiele/s) | 263,40 | 429,34 | 470,71 |
Diffusion auf Pokémon trainieren
In dieser Anleitung wird gezeigt, wie Sie das Stable Diffusion-Modell von HuggingFace mit dem Pokémon-Dataset auf Cloud TPU v5e trainieren.
Das Stable Diffusion-Modell ist ein latentes Text-zu-Bild-Modell, das fotorealistische Bilder aus beliebigen Texteingaben generiert. Weitere Informationen finden Sie in den folgenden Ressourcen:
Einrichten
Legen Sie eine Umgebungsvariable für den Namen Ihres Speicher-Buckets fest:
export GCS_BUCKET_NAME=your_bucket_name
Speicher-Bucket für die Modellausgabe einrichten:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. SERVICE_ACCOUNT
Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“. Beispiel:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet, sieht die Ausgabe in etwa so aus:state: ACTIVE
Installieren Sie JAX und die zugehörige Bibliothek.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Laden Sie das HuggingFace-Repository herunter und installieren Sie die Anforderungen.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
Modell trainieren
Trainieren Sie das Modell mit einem vorab zugeordneten Puffer von 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
Bereinigen
Löschen Sie am Ende der Sitzung Ihre TPU, die in die Warteschlange gestellte Ressource und den Cloud Storage-Bucket.
Löschen Sie Ihre TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Löschen Sie die in die Warteschlange gestellte Ressource:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Löschen Sie den Cloud Storage-Bucket:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
Benchmark-Ergebnisse für Diffusion
Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsätze aufgeführt.
Beschleunigertyp | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Trainingsschritt | 1500 | 1500 | 1500 |
Globale Batchgröße | 32 | 64 | 128 |
Durchsatz (Beispiele/s) | 36,53 | 43,71 | 49.36 |
PyTorch/XLA
In den folgenden Abschnitten werden Beispiele für das Trainieren von PyTorch/XLA-Modellen auf TPU v5e beschrieben.
ResNet mit der PJRT-Laufzeit trainieren
PyTorch/XLA wird ab PyTorch 2.0 von XRT zu PjRt migriert. Hier finden Sie die aktualisierte Anleitung zum Einrichten von v5e für PyTorch/XLA-Trainingsarbeitslasten.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. SERVICE_ACCOUNT
Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“. Beispiel:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die QueuedResource im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet, sieht die Ausgabe in etwa so aus:state: ACTIVE
Torch/XLA-spezifische Abhängigkeiten installieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
Ersetzen Sie
PYTORCH_VERSION
durch die PyTorch-Version, die Sie verwenden möchten.PYTORCH_VERSION
wird verwendet, um dieselbe Version für PyTorch/XLA anzugeben. Version 2.6.0 wird empfohlen.Weitere Informationen zu Versionen von PyTorch und PyTorch/XLA finden Sie unter PyTorch – Erste Schritte und PyTorch/XLA-Releases.
Weitere Informationen zur Installation von PyTorch/XLA finden Sie unter PyTorch/XLA installieren.
ResNet-Modell trainieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark-Ergebnis
In der folgenden Tabelle sehen Sie die Benchmark-Durchsätze.
Beschleunigertyp | Durchsatz (Beispiele/Sekunde) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
ViT auf v5e trainieren
In dieser Anleitung wird beschrieben, wie Sie VIT auf v5e mit dem HuggingFace-Repository für PyTorch/XLA auf dem cifar10-Dataset ausführen.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Beschreibungen von Umgebungsvariablen
Variable Beschreibung PROJECT_ID
Ihre Google Cloud Projekt-ID. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues. TPU_NAME
Der Name der TPU. ZONE
Die Zone, in der die TPU-VM erstellt werden soll. Weitere Informationen zu unterstützten Zonen finden Sie unter TPU-Regionen und ‑Zonen. ACCELERATOR_TYPE
Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen. RUNTIME_VERSION
Die Softwareversion der Cloud TPU. SERVICE_ACCOUNT
Die E‑Mail-Adresse für Ihr Dienstkonto. Sie finden sie in der Google Cloud Console auf der Seite „Dienstkonten“. Beispiel:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
Die vom Nutzer zugewiesene Text-ID der in die Warteschlange eingereihten Ressourcenanfrage. -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre QueuedResource im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die in die Warteschlange gestellte Ressource im Status
ACTIVE
befindet, sieht die Ausgabe in etwa so aus:state: ACTIVE
PyTorch/XLA-Abhängigkeiten installieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
Ersetzen Sie
PYTORCH_VERSION
durch die PyTorch-Version, die Sie verwenden möchten.PYTORCH_VERSION
wird verwendet, um dieselbe Version für PyTorch/XLA anzugeben. Version 2.6.0 wird empfohlen.Weitere Informationen zu Versionen von PyTorch und PyTorch/XLA finden Sie unter PyTorch – Erste Schritte und PyTorch/XLA-Releases.
Weitere Informationen zur Installation von PyTorch/XLA finden Sie unter PyTorch/XLA installieren.
Laden Sie das HuggingFace-Repository herunter und installieren Sie die Anforderungen.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Modell trainieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie Ihre TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark-Ergebnis
In der folgenden Tabelle sind die Benchmark-Durchsätze für verschiedene Beschleunigertypen aufgeführt.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoche | 3 | 3 | 3 |
Globale Batchgröße | 32 | 128 | 512 |
Durchsatz (Beispiele/s) | 201 | 657 | 2.844 |