Cloud TPU v5e-Training
Cloud TPU v5e ist der KI-Beschleuniger der neuesten Generation von Google Cloud. Mit einem weniger 256 Chips pro Pod, V5e ist für den höchsten Wert optimiert Produkt für Transformer, Text-to-Image und Convolutional Neural Network (CNN) Training, Feinabstimmung und Bereitstellung. Weitere Informationen zur Verwendung von Cloud TPU v5e für die Bereitstellung finden Sie unter Inferenz mit v5e.
Weitere Informationen zur Hardware und Konfiguration von Cloud TPU v5e-TPUs finden Sie unter TPU v5e:
Mehr erfahren
In den folgenden Abschnitten wird beschrieben, wie Sie mit TPU v5e beginnen.
Anfragekontingent
Sie benötigen ein Kontingent, um TPU v5e für das Training verwenden zu können. Es gibt verschiedene Kontingentarten für On-Demand-TPUs, reservierte TPUs und TPU-Spot-VMs. Wenn Sie Ihre TPU v5e für die Inferenz verwenden, sind separate Kontingente erforderlich. Weitere Informationen zu Kontingenten finden Sie unter Kontingente: Wenn Sie ein TPU v5e-Kontingent anfordern möchten, wenden Sie sich an Cloud Verkäufe.
Google Cloud-Konto und -Projekt erstellen
Sie benötigen ein Google Cloud-Konto und ein Projekt, um Cloud TPU verwenden zu können. Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.
Cloud TPU erstellen
Als Best Practice wird empfohlen, Cloud TPU v5-Ressourcen mit dem Befehl queued-resource create
als Ressourcen in der Warteschlange zu provisionieren. Weitere Informationen finden Sie unter In der Warteschlange verwalten
Ressourcen
Sie können Cloud TPU v5 auch mit der Create Node API (gcloud compute tpus tpu-vm create
) bereitstellen. Weitere Informationen finden Sie unter TPU verwalten
Ressourcen
Weitere Informationen zu verfügbaren v5e-Konfigurationen für das Training siehe Cloud TPU v5e-Typen für Schulung.
Framework einrichten
In diesem Abschnitt wird der allgemeine Einrichtungsprozess für das Training benutzerdefinierter Modelle mit JAX oder PyTorch mit TPU v5e beschrieben. TensorFlow-Unterstützung ist in der
tpu-vm-tf-2.17.0-pjrt
und tpu-vm-tf-2.17.0-pod-pjrt
TPU
Laufzeitversionen.
Eine Anleitung zur Einrichtung von Inferenzen finden Sie unter Einführung in die v5e-Inferenz.
Einrichtung für JAX
Wenn Sie Slice-Formen mit mehr als 8 Chips haben, befinden sich mehrere VMs in einem Slice. In diesem Fall müssen Sie das Flag --worker=all
verwenden, um die Installation auf allen TPU-VMs in einem einzigen Schritt auszuführen, ohne sich über SSH einzeln anzumelden:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Beschreibung der Befehls-Flags
Variable | Beschreibung |
TPU_NAME | Die vom Nutzer zugewiesene Text-ID der TPU, die erstellt wird, wenn die anstehende Ressourcenanfrage zugewiesen wird. |
PROJECT_ID | Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie ein neues unter Google Cloud-Projekt einrichten. |
ZONE | Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. |
Worker | Die TPU-VM, die Zugriff auf die zugrunde liegenden TPUs hat. |
Mit dem folgenden Befehl können Sie die Geräteanzahl prüfen. Die hier gezeigten Ausgaben wurden mit einem v5litepod-16-Stich erstellt. Mit diesem Code wird getestet, durch Überprüfung, ob JAX die Cloud TPU TensorCores erkennt, und kann grundlegende Vorgänge ausführen:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
Die Ausgabe sollte in etwa so aussehen:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
zeigt die Gesamtzahl der Chips im jeweiligen Segment an.
jax.local_device_count()
gibt die Anzahl der Chips an, auf die eine einzelne VM in diesem Slice zugreifen kann.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
Die Ausgabe sollte in etwa so aussehen:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
Sehen Sie sich die JAX-Anleitungen in diesem Dokument an, um mit dem Training mit v5e und JAX zu beginnen.
Einrichtung für PyTorch
Beachten Sie, dass Version 5e nur die PJRT-Laufzeit unterstützt. und PyTorch 2.1+ verwenden PJRT als Standardlaufzeit für alle TPU-Versionen.
In diesem Abschnitt wird beschrieben, wie Sie PJRT in v5e mit PyTorch/XLA mit Befehlen für alle Worker verwenden.
Abhängigkeiten installieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt-get update -y
sudo apt-get install libomp5 -y
pip3 install mkl mkl-include
pip3 install tf-nightly tb-nightly tbp-nightly
pip3 install numpy
sudo apt-get install libopenblas-dev -y
pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
Wenn beim Einbau der Räder für torch
, torch_xla
oder
torchvision
„Gefällt mir“-Angabe
pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
,
Führen Sie mit dem folgenden Befehl ein Downgrade Ihrer Version durch:
pip3 install setuptools==62.1.0
Skript mit PJRT ausführen
unset LD_PRELOAD
Im Folgenden finden Sie ein Beispiel, in dem ein Python-Skript für eine Berechnung verwendet wird auf einer v5e-VM:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU_C_API
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
Dadurch wird eine Ausgabe generiert, die etwa so aussieht:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
Sehen Sie sich die PyTorch-Tutorials in diesem Dokument an, um mit dem Training mit v5e und PyTorch zu beginnen.
Löschen Sie die TPU und die in die Warteschlange gestellten Ressourcen am Ende der Sitzung. So löschen Sie in die Warteschlange gestellt, löschen Sie das Segment und dann die Ressource in der Warteschlange in zwei Schritten:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Mit diesen beiden Schritten können Sie auch Ressourcenanfragen aus der Warteschlange entfernen, die sich im Status FAILED
befinden.
JAX/FLAX-Beispiele
In den folgenden Abschnitten werden Beispiele für das Trainieren von JAX- und FLAX-Modellen auf TPU v5e.
ImageNet mit v5e trainieren
In diesem Tutorial wird beschrieben, wie ImageNet in v5e mit fiktiven Eingabedaten trainiert wird. Wenn Sie echte Daten verwenden möchten, lesen Sie die README-Datei auf GitHub.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald die in der Warteschlange befindliche Ressource den Status
ACTIVE
hat:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die QueuedResource im Status
ACTIVE
befindet, sieht die Ausgabe in etwa so aus:state: ACTIVE
Installieren Sie die neueste Version von JAX und jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Klonen Sie das ImageNet-Modell und installieren Sie die entsprechenden Anforderungen:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'
Zum Generieren von Fake-Daten benötigt das Modell Informationen zu den Dimensionen des Datensatzes. Sie können sie aus den Metadaten des ImageNet-Datasets ermitteln:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'
Modell trainieren
Sobald Sie alle vorherigen Schritte ausgeführt haben, können Sie das Modell trainieren.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'
TPU und Ressource in der Warteschlange löschen
Löschen Sie die TPU und die in die Warteschlange gestellten Ressourcen am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hugging Face FLAX-Modelle
In FLAX implementierte Hugging Face-Modelle funktionieren ohne zusätzliche Anpassungen auf Cloud TPU v5e. In diesem Abschnitt finden Sie eine Anleitung zum Ausführen beliebter Modelle.
ViT auf Imagenette trainieren
In dieser Anleitung erfahren Sie, wie Sie Vision Transformer trainieren. (ViT)-Modell von HuggingFace mit der schnellen KI-Technologie von Imagenette Dataset in Cloud TPU v5e.
Das ViT-Modell war das erste, mit dem ein Transformer-Encoder erfolgreich auf ImageNet trainiert wurde. Im Vergleich zu Convolutional Neural Networks erzielte es hervorragende Ergebnisse. Weitere Informationen finden Sie in den folgenden Ressourcen:
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
Erstellen Sie eine TPU-Ressource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zur TPU-VM herstellen, sobald die Ressource in der Warteschlange steht. befindet sich im Status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die Ressource in der Warteschlange im Status
ACTIVE
befindet, sieht die Ausgabe ähnlich aus. zu:state: ACTIVE
Installieren Sie JAX und die zugehörige Bibliothek:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Laden Sie das Repository von Hugging Face herunter und installieren Sie die erforderlichen Pakete:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.17.0 && pip install -r examples/flax/vision/requirements.txt'
Laden Sie das Imagenette-Dataset herunter:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Modell trainieren
Trainieren Sie das Modell mit einem vorab zugeordneten Zwischenspeicher mit 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie die TPU und die in die Warteschlange gestellten Ressourcen am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
ViT-Benchmarking-Ergebnisse
Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. Die In der folgenden Tabelle sind die Durchsatzraten bei verschiedenen Beschleunigertypen aufgeführt.
Beschleunigertyp | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoche | 3 | 3 | 3 |
Globale Batchgröße | 32 | 128 | 512 |
Durchsatz (Beispiele/s) | 263,40 | 429,34 | 470,71 |
Diffusion mit Pokémon trainieren
In dieser Anleitung erfahren Sie, wie Sie das Stable Diffusion-Modell von HuggingFace mit dem Pokémon-Dataset auf Cloud TPU v5e trainieren.
Das Stable Diffusion-Modell ist ein latentes Text-zu-Bild-Modell, mit dem fotorealistische Bilder aus beliebigen Texteingaben generiert werden. Weitere Informationen finden Sie in den folgenden Ressourcen:
Einrichten
Richten Sie einen Speicher-Bucket für die Modellausgabe ein.
gcloud storage buckets create gs://your_bucket
--project=your_project
--location=us-west1
export GCS_BUCKET_NAME=your_bucketUmgebungsvariablen erstellen
export GCS_BUCKET_NAME=your_bucket export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west1-c export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
Beschreibung der Befehls-Flags
Variable Beschreibung GCS_BUCKET_NAME Wird in der Google Cloud Console unter „Cloud Storage“ -> „Buckets“ angezeigt PROJECT_ID Name des Google Cloud-Projekts. Verwenden Sie ein vorhandenes Projekt oder erstellen Sie unter Google Cloud-Projekt einrichten ein neues. ACCELERATOR_TYPE Siehe TPU-Versionen Seite für Ihre TPU-Version. ZONE Welche Zonen unterstützt werden, erfahren Sie im Dokument TPU-Regionen und ‑Zonen. RUNTIME_VERSION Verwenden Sie v2-alpha-tpuv5 für die RUNTIME_VERSION. SERVICE_ACCOUNT Das ist die Adresse Ihres Dienstkontos. Sie finden sie in der Google Cloud Console unter „IAM“ -> „Dienstkonten“. Beispiel: tpu-service-account@myprojectID.iam.gserviceaccount.com TPU_NAME Die vom Nutzer zugewiesene Text-ID der TPU, die erstellt wird, wenn wird die Ressourcenanfrage in der Warteschlange zugewiesen. QUEUED_RESOURCE_ID Die vom Nutzer zugewiesene Text-ID der anstehenden Ressourcenanfrage. Informationen zu Ressourcen in der Warteschlange finden Sie im Dokument Ressourcen in der Warteschlange. QUOTA_TYPE Kann reserved
oderspot
sein. Wenn keine dieser Optionen angegeben ist, wird der Wert für QUOTA_TYPE ist standardmäßigon-demand
. Informationen zu den verschiedenen Arten von Kontingenten, die von Cloud TPU unterstützt werden, finden Sie unter Kontingente.VALID_UNTIL_DURATION Die Dauer, für die die Anfrage gültig ist. Weitere Informationen finden Sie unter Ressourcen in der Warteschlange finden Sie Informationen zur gültigen Dauer. Erstellen Sie eine TPU-Ressource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zur TPU-VM herstellen, sobald sich die Ressource in der Warteschlange im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn die Ressource in der Warteschlange den Status
ACTIVE
hat, sieht die Ausgabe in etwa so aus:state: ACTIVE
Installieren Sie JAX und die zugehörige Bibliothek.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
HuggingFace-Repository herunterladen und Installationsvoraussetzungen.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install tensorflow==2.17.0 clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
Modell trainieren
Trainieren Sie das Modell mit einem vorab zugeordneten Zwischenspeicher mit 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project ${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use.
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=your_run base_output_directory=gs://${GCS_BUCKET_NAME}/ enable_profiler=False"
TPU und Ressource in der Warteschlange löschen
Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmarking-Ergebnisse für Diffusion
Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsatzraten aufgeführt.
Beschleunigertyp | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Trainingsschritt | 1500 | 1500 | 1500 |
Globale Batchgröße | 32 | 64 | 128 |
Durchsatz (Beispiele/Sek.) | 36,53 | 43,71 | 49,36 |
GPT2 im OSCAR-Dataset trainieren
In dieser Anleitung erfahren Sie, wie Sie das GPT2-Modell von HuggingFace mit dem OSCAR-Dataset auf Cloud TPU v5e trainieren.
GPT2 ist ein Transformer-Modell, das mit Rohtexten ohne Beschriftungen durch Menschen. Es wurde darauf trainiert, das nächste Wort in Sätzen vorherzusagen. Weitere Informationen finden Sie in den folgenden Ressourcen:
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in der Warteschlange befindliche Ressource im Status
ACTIVE
befindet:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn sich die Ressource in der Warteschlange im Status
ACTIVE
befindet, sieht die Ausgabe ähnlich aus. zu:state: ACTIVE
Installieren Sie JAX und die zugehörige Bibliothek.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
HuggingFace-Repository herunterladen und Installationsvoraussetzungen.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install TensorFlow && pip install -r examples/flax/language-modeling/requirements.txt'
Laden Sie Konfigurationen zum Trainieren des Modells herunter.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers/examples/flax/language-modeling && gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/v5litepod-preview/jax/gpt . --recursive'
Modell trainieren
Trainieren Sie das Modell mit einem vorab zugeordneten Zwischenspeicher mit 4 GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers/examples/flax/language-modeling && TPU_PREMAPPED_BUFFER_SIZE=4294967296 JAX_PLATFORMS=tpu python3 run_clm_flax.py --output_dir=./gpt --model_type=gpt2 --config_name=./gpt --tokenizer_name=./gpt --dataset_name=oscar --dataset_config_name=unshuffled_deduplicated_no --do_train --do_eval --block_size=512 --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --learning_rate=5e-3 --warmup_steps=1000 --adam_beta1=0.9 --adam_beta2=0.98 --weight_decay=0.01 --overwrite_output_dir --num_train_epochs=3 --logging_steps=500 --eval_steps=2500'
TPU und Ressource in der Warteschlange löschen
Löschen Sie die TPU und die in die Warteschlange gestellte Ressource am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmarking-Ergebnisse für GPT2
Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. In der folgenden Tabelle sind die Durchsatzraten aufgeführt.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoche | 3 | 3 | 3 |
Globale Batchgröße | 64 | 64 | 64 |
Durchsatz (Beispiele/Sek.) | 74,60 | 72,97 | 72,62 |
PyTorch/XLA
In den folgenden Abschnitten werden Beispiele zum Trainieren von PyTorch-/XLA-Modellen auf TPU v5e beschrieben.
ResNet mit der PJRT-Laufzeit trainieren
PyTorch/XLA wird von PyTorch 2.0 und höher von XRT zu PjRt migriert. Hier finden Sie die aktualisierte Anleitung zum Einrichten von v5e für PyTorch/XLA-Trainingsarbeitslasten.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=tpu-name export QUEUED_RESOURCE_ID=queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
Erstellen Sie eine TPU-Ressource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --{QUOTA_TYPE}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald Ihre QueuedResource befindet sich im Status
ACTIVE
:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn die Ressource in der Warteschlange den Status
ACTIVE
hat, sieht die Ausgabe in etwa so aus:state: ACTIVE
Torch-/XLA-spezifische Abhängigkeiten installieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
ResNet-Modell trainieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU_C_API
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git reset --hard caf5168785c081cd7eb60b49fe4fffeb894c39d9
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie die TPU und die in die Warteschlange gestellten Ressourcen am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmarkergebnis
In der folgenden Tabelle sind die Benchmarkdurchsätze aufgeführt.
Beschleunigertyp | Durchsatz (Beispiele/Sekunde) |
v5litepod-4 | 4.240 Ex/s |
v5litepod-16 | 10.810 ex/s |
v5litepod-64 | 46.154 ex/s |
GPT2 mit v5e trainieren
In dieser Anleitung erfahren Sie, wie Sie GPT2 in v5e mit dem Repository von HuggingFace auf PyTorch/XLA mit dem Wikitext-Dataset ausführen.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
Erstellen Sie eine TPU-Ressource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich Ihre QueuedResource im
ACTIVE
-Bundesstaat:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn die Ressource in der Warteschlange den Status
ACTIVE
hat, lautet die Ausgabe: etwa so aussehen:state: ACTIVE
Installieren Sie die PyTorch-/XLA-Abhängigkeiten.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get -y update sudo apt install -y libopenblas-base pip3 install torchvision pip3 uninstall -y torch pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
HuggingFace-Repository herunterladen und Installationsvoraussetzungen.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' git clone https://github.com/pytorch/xla.git pip install --upgrade accelerate git clone https://github.com/huggingface/transformers.git cd transformers git checkout ebdb185befaa821304d461ed6aa20a17e4dc3aa2 pip install . git log -1 pip install datasets evaluate scikit-learn '
Konfigurationen des vortrainierten Modells herunterladen
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/my_config_2.json transformers/examples/pytorch/language-modeling/ --recursive gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/fsdp_config.json transformers/examples/pytorch/language-modeling/'
Modell trainieren
Trainieren Sie das 2B-Modell mit einer Batchgröße von 16.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU_C_API
cd transformers/
export LD_LIBRARY_PATH=/usr/local/lib/
export PT_XLA_DEBUG=0
export USE_TORCH=ON
python3 examples/pytorch/xla_spawn.py \
--num_cores=4 \
examples/pytorch/language-modeling/run_clm.py \
--num_train_epochs=3 \
--dataset_name=wikitext \
--dataset_config_name=wikitext-2-raw-v1 \
--per_device_train_batch_size=16 \
--per_device_eval_batch_size=16 \
--do_train \
--do_eval \
--logging_dir=./tensorboard-metrics \
--cache_dir=./cache_dir \
--output_dir=/tmp/test-clm \
--overwrite_output_dir \
--cache_dir=/tmp \
--config_name=examples/pytorch/language-modeling/my_config_2.json \
--tokenizer_name=gpt2 \
--block_size=1024 \
--optim=adafactor \
--adafactor=true \
--save_strategy=no \
--logging_strategy=no \
--fsdp=full_shard \
--fsdp_config=examples/pytorch/language-modeling/fsdp_config.json'
TPU und Ressource in der Warteschlange löschen
Löschen Sie die TPU und die in die Warteschlange gestellten Ressourcen am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark-Ergebnis
Das Trainingsskript wurde auf v5litepod-4, v5litepod-16 und v5litepod-64 ausgeführt. Die Die folgende Tabelle zeigt die Benchmark-Durchsätze für verschiedene Beschleunigertypen.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoche | 3 | 3 | 3 |
config | 600 Mio. | 2 Milliarden | 16 Mrd. |
Globale Batchgröße | 64 | 128 | 256 |
Durchsatz (Beispiele/Sek.) | 66 | 77 | 31 |
ViT auf v5e trainieren
In dieser Anleitung erfahren Sie, wie Sie VIT auf v5e mit dem Repository von HuggingFace auf PyTorch/XLA auf dem CIFAR10-Dataset ausführen.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your_service_account export TPU_NAME=tpu-name export QUEUED_RESOURCE_ID=queued_resource_id export QUOTA_TYPE=quota_type export VALID_UNTIL_DURATION=1d
-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --valid-until-duration=${VALID_UNTIL_DURATION} \ --service-account=${SERVICE_ACCOUNT} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald die QueuedResource den Status
ACTIVE
hat:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Wenn die Ressource in der Warteschlange den Status
ACTIVE
hat, lautet die Ausgabe: etwa so aussehen:state: ACTIVE
PyTorch/XLA-Abhängigkeiten installieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
Laden Sie das Repository von HuggingFace herunter und erfüllen Sie die Installationsanforderungen.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
Modell trainieren
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU_C_API
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
TPU und in die Warteschlange gestellte Ressource löschen
Löschen Sie die TPU und die in die Warteschlange gestellten Ressourcen am Ende der Sitzung.
gcloud compute tpus tpu-vm delete ${TPU_NAME}
--project=${PROJECT_ID}
--zone=${ZONE}
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID}
--project=${PROJECT_ID}
--zone=${ZONE}
--quiet
Benchmarkergebnis
Die folgende Tabelle zeigt die Benchmark-Durchsätze für verschiedene Beschleunigertypen.
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
Epoche | 3 | 3 | 3 |
Globale Batchgröße | 32 | 128 | 512 |
Durchsatz (Beispiele/Sek.) | 201 | 657 | 2.844 |
TensorFlow 2.x
In den folgenden Abschnitten werden Beispiele für das Training von TensorFlow 2.x beschrieben. auf TPU v5e-Modellen.
ResNet auf einer TPU v5e mit einem Host trainieren
In dieser Anleitung wird beschrieben, wie Sie ImageNet auf v5litepod-4
oder v5litepod-8
mit einem fiktiven Dataset trainieren. Wenn Sie ein anderes Dataset verwenden möchten, lesen Sie den Hilfeartikel Dataset vorbereiten.
Einrichten
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your-project-ID export ACCELERATOR_TYPE=v5litepod-4 export ZONE=us-east1-c export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pjrt export TPU_NAME=your-tpu-name export QUEUED_RESOURCE_ID=your-queued-resource-id export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE
kann entwederv5litepod-4
oderv5litepod-8
sein.Erstellen Sie eine TPU-Ressource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zur TPU-VM herstellen, sobald sich die Ressource in der Warteschlange befindet.
ACTIVE
. Verwenden Sie den folgenden Befehl, um den Status der in der Warteschlange befindlichen Ressource zu prüfen:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Über SSH eine Verbindung zur TPU herstellen
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Umgebungsvariablen festlegen
export MODELS_REPO=/usr/share/tpu/models export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}" export MODEL_DIR=gcp-directory-to-store-model export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet export NEXT_PLUGGABLE_DEVICE_USE_C_API=true export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
Wechseln Sie in das Repository-Verzeichnis für die Modelle und installieren Sie die erforderlichen Pakete.
cd ${MODELS_REPO} && git checkout r2.15.0 pip install -r official/requirements.txt
Modell trainieren
Führen Sie das Trainingsskript aus.
python3 official/vision/train.py \
--tpu=local \
--experiment=resnet_imagenet \
--mode=train_and_eval \
--config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
--model_dir=${MODEL_DIR} \
--params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
TPU und Ressource in der Warteschlange löschen
TPU löschen
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Ressourcenanfrage in der Warteschlange löschen
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Resnet auf einem v5e mit mehreren Hosts trainieren
In dieser Anleitung wird beschrieben, wie Sie ImageNet auf v5litepod-16
oder größer mit
ein fiktives Dataset. Wenn Sie ein anderes Dataset verwenden möchten, finden Sie weitere Informationen unter Dataset vorbereiten.
Erstellen Sie Umgebungsvariablen:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-east1-c export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pod-pjrt export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your-queued-resource-id export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE
kann entwederv5litepod-16
oder größer sein.-
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --${QUOTA_TYPE}
Sie können eine SSH-Verbindung zu Ihrer TPU-VM herstellen, sobald sich die in der Warteschlange befindliche Ressource im Status
ACTIVE
befindet. Prüfen Sie den Status der in die Warteschlange gestellten Ressource mithilfe des folgenden Befehl:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Über SSH eine Verbindung zu Ihrer TPU (Worker Zero) herstellen
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE}
Umgebungsvariablen festlegen
export MODELS_REPO=/usr/share/tpu/models export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}" export MODEL_DIR=gcp-directory-to-store-model export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet export TPU_LOAD_LIBRARY=0 export TPU_NAME=your_tpu_name
Wechseln Sie in das Repository-Verzeichnis für die Modelle und installieren Sie die erforderlichen Pakete.
cd $MODELS_REPO && git checkout r2.15.0 pip install -r official/requirements.txt
Modell trainieren
Führen Sie das Trainingsskript aus.
python3 official/vision/train.py \
--tpu=${TPU_NAME} \
--experiment=resnet_imagenet \
--mode=train_and_eval \
--model_dir=${MODEL_DIR} \
--params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*"
TPU und Ressource in der Warteschlange löschen
TPU löschen
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
Ressourcenanfrage in der Warteschlange löschen
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet