Mit Pax auf einer TPU mit einzelnem Host trainieren


In diesem Dokument erhalten Sie eine kurze Einführung in die Arbeit mit Pax auf einer TPU mit einem einzelnen Host (v2-8, v3-8, v4-8).

Pax ist ein Framework zum Konfigurieren und Ausführen von ML-Tests auf Basis von JAX. Pax konzentriert sich darauf, ML im großen Maßstab zu vereinfachen, indem Infrastrukturkomponenten mit vorhandenen ML-Frameworks geteilt und die Modellierungsbibliothek Praxis für Modularität verwendet werden.

Lernziele

  • TPU-Ressourcen für das Training einrichten
  • Pax auf einer TPU mit einem einzelnen Host installieren
  • Transformerbasiertes SPMD-Modell mit Pax trainieren

Hinweise

Führen Sie die folgenden Befehle aus, um gcloud für die Verwendung Ihres Cloud TPU-Projekts zu konfigurieren und die Komponenten zu installieren, die zum Trainieren eines Modells mit Pax auf einer TPU mit einem einzelnen Host erforderlich sind.

Google Cloud CLI installieren

Die Google Cloud CLI enthält Tools und Bibliotheken für die Interaktion mit Produkten und Diensten der Google Cloud CLI. Wenn Sie sie noch nicht installiert haben, folgen Sie der Anleitung unter Google Cloud CLI installieren, um sie jetzt zu installieren.

gcloud-Befehl konfigurieren

(Führen Sie gcloud auth list aus, um Ihre verfügbaren Konten aufzurufen.)

$ gcloud config set account account

$ gcloud config set project project-id

Cloud TPU API aktivieren

Aktivieren Sie die Cloud TPU API mit dem folgenden gcloud-Befehl in der Cloud Shell. Sie können sie auch in der Google Cloud Console aktivieren.

$ gcloud services enable tpu.googleapis.com

Führen Sie den folgenden Befehl aus, um eine Dienstidentität (ein Dienstkonto) zu erstellen.

$ gcloud beta services identity create --service tpu.googleapis.com

TPU-VM erstellen

Bei Cloud TPU-VMs werden Ihr Modell und Ihr Code direkt auf der TPU-VM ausgeführt. Sie stellen eine SSH-Verbindung direkt zur TPU-VM her. Sie können beliebigen Code ausführen, Pakete installieren, Logs ansehen und Code direkt auf der TPU-VM debuggen.

Erstellen Sie Ihre TPU-VM, indem Sie den folgenden Befehl in einer Cloud Shell oder Ihrem Computerterminal ausführen, in dem die Google Cloud CLI installiert ist.

Legen Sie den zone basierend auf der Verfügbarkeit in Ihrem Vertrag fest. Weitere Informationen finden Sie unter TPU-Regionen und ‑Zonen.

Legen Sie die Variable accelerator-type auf „v2-8“, „v3-8“ oder „v4-8“ fest.

Legen Sie die Variable version für TPU-Versionen v2 und v3 auf tpu-vm-base oder für TPUs v4 auf tpu-vm-v4-base fest.

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version

Verbindung zur Google Cloud TPU-VM herstellen

Stellen Sie eine SSH-Verbindung zu Ihrer TPU-VM her, indem Sie den folgenden Befehl verwenden:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

Wenn Sie bei der VM angemeldet sind, ändert sich die Shell-Eingabeaufforderung von username@projectname in username@vm-name:

Pax auf der Google Cloud TPU-VM installieren

Installieren Sie Pax, JAX und libtpu mit den folgenden Befehlen auf Ihrer TPU-VM:

(vm)$ python3 -m pip install -U pip \
python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Systemüberprüfung

Prüfen Sie, ob alles korrekt installiert ist. Prüfen Sie dazu, ob JAX die TPU-Kerne sieht:

(vm)$ python3 -c "import jax; print(jax.device_count())"

Die Anzahl der TPU-Kerne wird angezeigt. Diese sollte 8 sein, wenn Sie eine v2-8 oder v3-8 verwenden, oder 4, wenn Sie eine v4-8 verwenden.

Pax-Code auf einer TPU-VM ausführen

Sie können jetzt jeden Pax-Code ausführen. Die lm_cloud-Beispiele sind ein guter Ausgangspunkt, um Modelle in Pax auszuführen. Mit den folgenden Befehlen wird beispielsweise ein Transformer-basiertes SPMD-Sprachmodell mit 2 Milliarden Parametern anhand synthetischer Daten trainiert.

Die folgenden Befehle zeigen die Trainingsausgabe für ein SPMD-Sprachmodell. Es trainiert in etwa 20 Minuten 300 Schritte.

(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

Für die Version 4–8 sollte die Ausgabe Folgendes enthalten:

Verluste und Schrittzeiten

summary tensor at step=step_# loss = loss
summary tensor at step=step_# Schritte pro Sekunde x

Bereinigen

Damit Ihrem Google Cloud-Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, löschen Sie entweder das Projekt, das die Ressourcen enthält, oder Sie behalten das Projekt und löschen die einzelnen Ressourcen.

Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.

Trennen Sie die Verbindung zur Compute Engine-Instanz, sofern noch nicht geschehen:

(vm)$ exit

Löschen Sie Ihre Cloud TPU.

$ gcloud compute tpus tpu-vm delete tpu-name  --zone zone

Nächste Schritte

Weitere Informationen zu Cloud TPU finden Sie unter: