Berechnung mit JAX auf einer Cloud TPU-VM ausführen

In diesem Dokument erhalten Sie eine kurze Einführung in die Arbeit mit JAX und Cloud TPU.

Hinweise

Bevor Sie die Befehle in diesem Dokument ausführen, müssen Sie ein Google Cloud-Konto erstellen, die Google Cloud CLI installieren und den gcloud-Befehl konfigurieren. Weitere Informationen finden Sie unter Cloud TPU-Umgebung einrichten.

Cloud TPU-VM mit gcloud erstellen

  1. Definieren Sie einige Umgebungsvariablen, um die Befehle nutzerfreundlicher zu gestalten.

    export PROJECT_ID=your-project
    export ACCELERATOR_TYPE=v5p-8
    export ZONE=us-east5-a
    export RUNTIME_VERSION=v2-alpha-tpuv5
    export TPU_NAME=your-tpu-name

    Beschreibungen von Umgebungsvariablen

    PROJECT_ID
    Ihre Google Cloud Projekt-ID.
    ACCELERATOR_TYPE
    Mit dem Beschleunigertyp geben Sie die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu den unterstützten Beschleunigertypen für jede TPU-Version finden Sie unter TPU-Versionen.
    ZONE
    Die Zone, in der Sie die Cloud TPU erstellen möchten.
    RUNTIME_VERSION
    Die Version der Cloud TPU-Laufzeit. Weitere Informationen finden Sie unter TPU-VM-Images
    .
    TPU_NAME
    Der vom Nutzer zugewiesene Name für Ihre Cloud TPU.
  2. Erstellen Sie Ihre TPU-VM, indem Sie den folgenden Befehl in einer Cloud Shell oder Ihrem Computerterminal ausführen, in dem die Google Cloud CLI installiert ist.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION

Verbindung zur Cloud TPU-VM herstellen

Stellen Sie über SSH mit dem folgenden Befehl eine Verbindung zu Ihrer TPU-VM her:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
--project=$PROJECT_ID \
--zone=$ZONE

JAX auf der Cloud TPU-VM installieren

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Systemüberprüfung

Prüfen Sie, ob JAX auf die TPU zugreifen und grundlegende Vorgänge ausführen kann:

  1. Starten Sie den Python 3-Interpreter:

    (vm)$ python3
    >>> import jax
  2. Rufen Sie die Anzahl der verfügbaren TPU-Kerne auf:

    >>> jax.device_count()

Die Anzahl der TPU-Kerne wird angezeigt. Die angezeigte Anzahl der Kerne hängt von der verwendeten TPU-Version ab. Weitere Informationen finden Sie unter TPU-Versionen.

Führen Sie eine Berechnung durch:

>>> jax.numpy.add(1, 1)

Das Ergebnis von "numpy add" wird angezeigt:

Ausgabe des Befehls:

Array(2, dtype=int32, weak_type=true)

Beenden Sie den Python-Interpreter:

>>> exit()

JAX-Code auf einer TPU-VM ausführen

Sie können jetzt jeden JAX-Code ausführen. Die flax-Beispiele sind ein guter Ausgangspunkt, um Standard-ML-Modelle in JAX auszuführen. So trainieren Sie beispielsweise ein einfaches MNIST-Convolutional Network (eine Art „faltendes neuronales Netzwerk“):

  1. Abhängigkeiten für Flax-Beispiele installieren

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. FLAX installieren

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. FLAX MNIST-Trainingsskript ausführen

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

Das Script lädt den Datensatz herunter und startet das Training. Die Ausgabe des Scripts sollte in etwa so aussehen:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Bereinigen

Mit den folgenden Schritten vermeiden Sie, dass Ihrem Google Cloud -Konto die in dieser Anleitung verwendeten Ressourcen in Rechnung gestellt werden:

Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.

  1. Trennen Sie die Verbindung zur Compute Engine-Instanz, sofern noch nicht geschehen:

    (vm)$ exit
  2. Löschen Sie Ihre Cloud TPU.

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
      --project=$PROJECT_ID \
      --zone=$ZONE
  3. Überprüfen Sie mit dem folgenden Befehl, ob die Ressourcen gelöscht wurden. Achten Sie darauf, dass Ihre TPU nicht mehr aufgeführt wird. Der Löschvorgang kann einige Minuten dauern.

    $ gcloud compute tpus tpu-vm list \
      --zone=$ZONE

Hinweise zur Leistung

Im Folgenden finden Sie einige wichtige Details, die für die Verwendung von TPUs in JAX besonders relevant sind.

Padding

Eine der häufigsten Ursachen für eine langsame Leistung auf TPUs stellt versehentliches Padding dar:

  • Arrays in der Cloud TPU sind gekachelt. Dies bedeutet, dass eine der Dimensionen auf ein Vielfaches von 8 und eine andere auf ein Vielfaches von 128 aufgefüllt wird.
  • Die Matrixmultiplikationseinheit funktioniert am besten mit Paaren großer Matrizen, die die Notwendigkeit von Padding minimieren.

bfloat16 dtype

Standardmäßig verwendet die Matrixmultiplikation in JAX auf TPUs bfloat16 mit der Akkumulation float32. Dies kann mit dem Genauigkeitsargument für relevante "jax.numpy"-Funktionsaufrufe (matmul, dot, einsum usw.) gesteuert werden. Beispiele:

  • precision=jax.lax.Precision.DEFAULT: verwendet die gemischte bfloat16-Genauigkeit (am schnellsten)
  • precision=jax.lax.Precision.HIGH verwendet mehrere MXU-Durchläufe, um eine höhere Genauigkeit zu erreichen
  • precision=jax.lax.Precision.HIGHEST verwendet noch mehr MXU-Durchläufe, um eine vollständige float32-Genauigkeit zu erreichen

JAX fügt außerdem den bfloat16-dtype hinzu, mit dem Sie explizit Arrays in bfloat16 umwandeln können, z. B. jax.numpy.array(x, dtype=jax.numpy.bfloat16).

Nächste Schritte

Weitere Informationen zu Cloud TPU finden Sie unter: