Berechnung auf einer Cloud TPU-VM mit JAX ausführen

Dieses Dokument bietet eine kurze Einführung in die Arbeit mit JavaX und Cloud TPU.

Bevor Sie dieser Kurzanleitung folgen, müssen Sie ein Google Cloud Platform-Konto erstellen, die Google Cloud CLI installieren und den Befehl gcloud konfigurieren. Weitere Informationen finden Sie unter Konto und Cloud TPU-Projekt einrichten.

Google Cloud CLI installieren

Die Google Cloud CLI enthält Tools und Bibliotheken für die Interaktion mit Google Cloud-Produkten und -Diensten. Weitere Informationen finden Sie unter Google Cloud CLI installieren.

gcloud-Befehl konfigurieren

Führen Sie die folgenden Befehle aus, um gcloud für die Verwendung Ihres Google Cloud-Projekts zu konfigurieren und Komponenten zu installieren, die für die TPU-VM-Vorschau erforderlich sind.

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Cloud TPU API aktivieren

  1. Aktivieren Sie die Cloud TPU API mit dem folgenden gcloud-Befehl in Cloud Shell. Sie können es auch über die Google Cloud Console aktivieren.

    $ gcloud services enable tpu.googleapis.com
    
  2. Führen Sie den folgenden Befehl aus, um eine Dienstidentität zu erstellen.

    $ gcloud beta services identity create --service tpu.googleapis.com
    

Cloud TPU-VM mit gcloud erstellen

Mit Cloud TPU-VMs werden Modell und Code direkt auf der TPU-Hostmaschine ausgeführt. Sie können eine SSH-Verbindung direkt zum TPU-Host herstellen. Sie können beliebigen Code direkt auf dem TPU-Host ausführen, Pakete installieren, Logs ansehen und Code debuggen.

  1. Erstellen Sie Ihre TPU-VM. Führen Sie dazu den folgenden Befehl in einer Cloud Shell oder auf dem Computerterminal aus, auf dem die Google Cloud CLI installiert ist.

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base
    

    Pflichtfelder

    zone
    Die Zone, in der Sie Ihre Cloud TPU erstellen möchten.
    accelerator-type
    Der Beschleunigertyp gibt die Version und Größe der Cloud TPU an, die Sie erstellen möchten. Weitere Informationen zu unterstützten Beschleunigertypen für die einzelnen TPU-Versionen finden Sie unter TPU-Versionen.
    version
    Die Cloud TPU-Softwareversion. Verwenden Sie für alle TPU-Typen tpu-ubuntu2204-base.

Verbindung zur Cloud TPU-VM herstellen

Stellen Sie eine SSH-Verbindung zu Ihrer TPU-VM her, indem Sie den folgenden Befehl verwenden:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

Pflichtfelder

tpu_name
Der Name der TPU-VM, mit der Sie eine Verbindung herstellen möchten.
zone
Die Zone, in der Sie die Cloud TPU erstellt haben.

JAX auf der Cloud TPU-VM installieren

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Systemüberprüfung

Prüfen Sie, ob JAX auf die TPU zugreifen und grundlegende Vorgänge ausführen kann:

Starten Sie den Python 3-Interpreter:

(vm)$ python3
>>> import jax

Rufen Sie die Anzahl der verfügbaren TPU-Kerne auf:

>>> jax.device_count()

Die Anzahl der TPU-Kerne wird angezeigt. Wenn Sie eine v4 TPU verwenden, sollte dieser 4 lauten. Wenn Sie eine v2- oder v3-TPU verwenden, sollte dieser 8 lauten.

Führen Sie eine einfache Berechnung durch:

>>> jax.numpy.add(1, 1)

Das Ergebnis von "numpy add" wird angezeigt:

Ausgabe des Befehls:

Array(2, dtype=int32, weak_type=true)

Beenden Sie den Python-Interpreter:

>>> exit()

JAX-Code auf einer TPU-VM ausführen

Sie können nun jeden beliebigen JAX-Code ausführen. Die flax-Beispiele sind ein guter Ausgangspunkt, um Standard-ML-Modelle in JAX auszuführen. So trainieren Sie beispielsweise ein einfaches MNIST Convolutional Network:

  1. Flax-Beispielabhängigkeiten installieren

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
    
  2. FLAX installieren

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
    
  3. FLAX MNIST-Trainingsskript ausführen

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
    

Das Skript lädt das Dataset herunter und beginnt mit dem Training. Die Skriptausgabe sollte in etwa so aussehen:

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Bereinigen

Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.

  1. Trennen Sie die Verbindung zur Compute Engine-Instanz, sofern noch nicht geschehen:

    (vm)$ exit
    
  2. Löschen Sie Ihre Cloud TPU.

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
    
  3. Überprüfen Sie mit dem folgenden Befehl, ob die Ressourcen gelöscht wurden. Achten Sie darauf, dass Ihre TPU nicht mehr aufgeführt wird. Der Löschvorgang kann einige Minuten dauern.

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b
    

Hinweise zur Leistung

Im Folgenden finden Sie einige wichtige Details, die für die Verwendung von TPUs in JAX besonders relevant sind.

Padding

Eine der häufigsten Ursachen für eine langsame Leistung auf TPUs stellt versehentliches Padding dar:

  • Arrays in der Cloud TPU sind gekachelt. Dies bedeutet, dass eine der Dimensionen auf ein Vielfaches von 8 und eine andere auf ein Vielfaches von 128 aufgefüllt wird.
  • Die Matrixmultiplikationseinheit funktioniert am besten mit Paaren großer Matrizen, die die Notwendigkeit von Padding minimieren.

bfloat16 dtype

Standardmäßig verwendet die Matrixmultiplikation in JAX auf TPUs bfloat16 mit FLOAT32-Akkumulation. Dies kann mit dem Genauigkeitsargument für relevante "jax.numpy"-Funktionsaufrufe (matmul, dot, einsum usw.) gesteuert werden. Beispiele:

  • precision=jax.lax.Precision.DEFAULT: verwendet gemischte bfloat16-Genauigkeit (am schnellsten)
  • precision=jax.lax.Precision.HIGH: verwendet mehrere MXU-Pässe, um eine höhere Genauigkeit zu erzielen
  • precision=jax.lax.Precision.HIGHEST: verwendet noch mehr MXU-Karten/Tickets, um eine vollständige Float32-Genauigkeit zu erreichen

JAX fügt außerdem den dtype bfloat16 hinzu, mit dem Sie Arrays explizit in bfloat16 umwandeln können, z. B. jax.numpy.array(x, dtype=jax.numpy.bfloat16).

JAX in einem Colab ausführen

Wenn Sie JAX-Code in einem Colab-Notebook ausführen, erstellt Colab automatisch einen Legacy-TPU-Knoten. TPU-Knoten haben eine andere Architektur. Weitere Informationen finden Sie unter Systemarchitektur.

Nächste Schritte

Weitere Informationen zu Cloud TPU finden Sie unter: