Fehlerbehebung für JAX – TPU

Dieser Leitfaden enthält Verweise auf Informationen zur Fehlerbehebung für JAX, damit Sie Probleme erkennen und beheben können, die beim Trainieren von JAX-Modellen auf Cloud TPU auftreten können.

Eine allgemeine Anleitung zum Einstieg in Cloud TPU finden Sie in der JAX-Kurzanleitung.

Allgemeine JAX-Probleme

Wenn bei der Entwicklung Ihres Trainingsmodells oder beim Training mit JAX Probleme auftreten, lesen Sie die FAQs zu JAX.

Allgemeinere Programmierfehler, die beim Schreiben einer Trainingsanwendung mit JAX auftreten können, finden Sie unter JAX-Fehler.

Profilerstellung für JAX-Leistung

Mithilfe der unter Profilerstellung für JAX-Leistung beschriebenen Tools können Sie nachvollziehen, wie Ihre TPU-Ressourcen genutzt werden.

Arbeitsspeicherprobleme beheben

Mit dem JAX Device Memory Profiler können Sie die Nutzung des Arbeitsspeichers überwachen, aber nicht direkt festlegen, wie er verwendet wird.

Der Gerätespeicher-Profiler kann für Folgendes verwendet werden:

Sie können nicht angeben, wie TPU-Arbeitsspeicher für bestimmte Vorgänge zugewiesen wird. Weitere Informationen zu JAX-spezifischen TPU-Leistungsproblemen finden Sie unter Leistungshinweise für die Verwendung von TPUs mit JAX.

TPU-Probleme beheben

Wie kann ich prüfen, ob die TPU ausgeführt wird?

Details

Alle Vorgänge werden auf der TPU ausgeführt, solange JAX nicht "Keine GPU/TPU gefunden" oder "Fallback auf CPU" ausgibt.

Sie können prüfen, ob die TPU aktiv ist. Dazu rufen Sie jax.devices() auf, wo mehrere TPU-Geräte angezeigt werden sollen, oder prüfen programmatisch mit assert jax.devices()[0].platform == 'tpu'.

RuntimeError: Back-End "tpu" kann nicht initialisiert werden: UNAVAILABLE: Keine TPU-Plattform verfügbar.

Details

Diese Laufzeitfehlermeldung und/oder folgende Ergebnisse in /tmp/tpu_logs/tpu_driver.WARNING auf der TPU-VM: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx können darauf hinweisen, dass Sie die falsche TPU-VM-Version ausführen.

Stellen Sie sicher, dass Sie die aktuelle JAX-Laufzeitversion ausführen, und versuchen Sie es noch einmal.