Cloud TPU VM JAX-Kurzanleitung

Dieses Dokument bietet eine kurze Einführung in die Arbeit mit JAX und Cloud TPU.

Melden Sie sich bei Ihrem Google-Konto an. Wenn Sie noch kein Konto haben, registrieren Sie sich für ein neues Konto. Wählen Sie in der Google Cloud Console auf der Projektauswahlseite ein Cloud-Projekt aus oder erstellen Sie eines. Die Abrechnung für das Projekt muss aktiviert sein.

Google Cloud SDK installieren

Das Google Cloud SDK enthält Tools und Bibliotheken für die Interaktion mit Produkten und Diensten von Google Cloud. Weitere Informationen finden Sie unter Google Cloud SDK installieren.

gcloud-Befehl konfigurieren

Führen Sie die folgenden Befehle aus, um gcloud für die Verwendung Ihres GCP-Projekts zu konfigurieren und für die TPU-VM-Vorschau erforderliche Komponenten zu installieren.

  $ gcloud config set account your-email-account
  $ gcloud config set project project-id

Cloud TPU API aktivieren

  1. Aktivieren Sie die Cloud TPU API mit dem folgenden gcloud-Befehl in Cloud Shell. (Sie können es auch über die Google Cloud Console aktivieren.)

    $ gcloud services enable tpu.googleapis.com
    
  2. Führen Sie den folgenden Befehl aus, um eine Dienstidentität zu erstellen.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Cloud TPU-VM mit gcloud erstellen

Bei Cloud TPU-VMs werden Ihr Modell und Ihr Code direkt auf dem TPU-Hostcomputer ausgeführt. Stellen Sie eine SSH-Verbindung zum TPU-Host her. Sie können beliebigen Code ausführen, Pakete installieren, Logs ansehen und Fehler direkt auf dem TPU-Host beheben.

  1. Erstellen Sie Ihre TPU-VM, indem Sie den folgenden Befehl in einer GCP-Cloud Shell oder auf Ihrem Computerterminal ausführen, auf dem das Google Cloud SDK installiert ist.

    (vm)$ gcloud alpha compute tpus tpu-vm create tpu-name \
    --zone europe-west4-a \
    --accelerator-type v3-8 \
    --version v2-alpha

    Pflichtfelder

    zone
    Die Zone, in der Sie die Cloud TPU erstellen möchten.
    accelerator-type
    Der Typ der zu erstellenden Cloud TPU.
    version
    Die Cloud TPU-Laufzeitversion. Legen Sie diesen Wert auf "v2-alpha" fest, wenn Sie JAX für einzelne TPU-Geräte, Pod-Slices oder ganze Pods verwenden.

Verbindung zur Cloud TPU-VM herstellen

Stellen Sie mit dem folgenden Befehl eine SSH-Verbindung zu Ihrer TPU-VM her:

$ gcloud alpha compute tpus tpu-vm ssh tpu-name --zone europe-west4-a

Pflichtfelder

tpu_name
Der Name der TPU-VM, mit der Sie eine Verbindung herstellen.
zone
Die Zone, in der Sie die Cloud TPU erstellt haben.

JAX auf der Cloud TPU-VM installieren

(vm)$ pip3 install --upgrade jax jaxlib

Systemprüfung

Testen Sie, ob alles korrekt installiert wurde. Dazu überprüfen Sie, ob JAX die Cloud TPU-Kerne sieht und grundlegende Vorgänge ausführen kann:

Starten Sie den Python 3-Interpreter:

(vm)$ python3
>>> import jax

Lassen Sie sich die Anzahl der verfügbaren TPU-Kerne anzeigen:

>>> jax.device_count()

Die Anzahl der TPU-Kerne wird angezeigt. Dies sollte 8 sein.

Führen Sie eine einfache Berechnung durch:

>>> jax.numpy.add(1, 1)

Das Ergebnis der Anzeige "numpy" wird angezeigt:

Ausgabe des Befehls:

DeviceArray(2, dtype=int32)

Beenden Sie den Python-Interpreter:

>>> exit()

JAX-Code auf einer TPU-VM ausführen

Sie können jetzt einen beliebigen JAX-Code ausführen. Die flax-Beispiele sind ein guter Ausgangspunkt für die Ausführung von standardmäßigen ML-Modellen in JAX. So trainieren Sie beispielsweise ein einfaches MNIST-Faltungsnetzwerk:

  1. TensorFlow-Datasets installieren

    (vm)$ pip install --upgrade clu
    
  2. Installieren Sie FLAX.

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user -e flax
    
  3. FLAX MNIST-Trainingsskript ausführen

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

    Die Skriptausgabe sollte so aussehen:

    I0513 21:09:35.448946 140431261813824 train.py:125] train epoch: 1, loss: 0.2312, accuracy: 93.00
    I0513 21:09:36.402860 140431261813824 train.py:176] eval epoch: 1, loss: 0.0563, accuracy: 98.05
    I0513 21:09:37.321380

Bereinigen

Wenn Sie mit der TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.

  1. Trennen Sie die Verbindung zur Compute Engine-Instanz, sofern noch nicht geschehen:

    (vm)$ exit
    
  2. Löschen Sie Ihre Cloud TPU.

    $ gcloud alpha compute tpus tpu-vm delete tpu-name \
      --zone europe-west4-a
    
  3. Prüfen Sie mit dem folgenden Befehl, ob die Ressourcen gelöscht wurden: Achten Sie darauf, dass Ihre TPU nicht mehr aufgeführt ist. Der Löschvorgang kann einige Minuten dauern.

Hinweise zur Leistung

Im Folgenden finden Sie einige wichtige Details, die für die Verwendung von TPUs in JAX besonders relevant sind.

Padding

Eine der häufigsten Ursachen für eine langsame Leistung auf TPUs ist die Verwendung eines kontinuierlichen Paddings:

  • Arrays in der Cloud TPU sind gekachelt. Dies bedeutet, dass eine der Dimensionen auf ein Vielfaches von 8 und eine andere auf ein Vielfaches von 128 aufgefüllt wird.
  • Die Matrixmultiplikationseinheit funktioniert am besten mit Paaren großer Matrizen, die die erforderliche Auffüllung minimieren.

bfloat16 dtype

Standardmäßig verwendet die Matrixmultiplikation in JAX auf TPUs bfloat16 mit der Akkumulation float32. Dies kann mit dem Genauigkeitsargument für relevante Jax.numpy-Funktionsaufrufe (matamul, dot, sumsum usw.) gesteuert werden. Beispiele:

  • precision=jax.lax.Precision.DEFAULT: Verwendet eine Mischung aus bfloat16 (Genauigkeit) (schnellste)
  • precision=jax.lax.Precision.HIGH: Verwendet mehrere MXU-Tickets, um eine höhere Genauigkeit zu erreichen
  • precision=jax.lax.Precision.HIGHEST: verwendet noch mehr MXU-Passkarten, um eine volle float32-Genauigkeit zu erreichen

JAX fügt auch den bfloat16 dtype hinzu, mit dem Sie Arrays explizit in bfloat16 umwandeln können. Beispiel: jax.numpy.array(x, dtype=jax.numpy.bfloat16).

JAX in einem Colab ausführen

Wenn Sie JAX-Code in einem Colab-Notebook ausführen, erstellt Colab automatisch einen alten TPU-Knoten. TPU-Knoten haben eine andere Architektur. Weitere Informationen finden Sie unter Systemarchitektur.