JAX - TPU のトラブルシューティング

このガイドでは、Cloud TPU で JAX モデルをトレーニングする際に発生する可能性のある問題を特定して解決するため、JAX のトラブルシューティング情報について説明します。

Cloud TPU を使い始める際の一般的なガイドについては、JAX クイックスタートをご覧ください。

JAX に関する一般的な問題

トレーニング モデルの開発中や JAX でのトレーニング中に問題が発生した場合は、JAX のよくある質問をご覧ください。

JAX を使用してトレーニング アプリケーションを作成するときに発生する可能性のある一般的なプログラミング エラーについては、JAX エラーをご覧ください。

JAX パフォーマンスをプロファイリングする

JAX パフォーマンスのプロファイリングで説明されているツールを使用して、TPU リソースの使用状況を把握できます。

メモリの問題のトラブルシューティング

JAX Device Memory Profiler でメモリの使用状況をモニタリングできますが、その使用状況について直接管理することはできません。

Device Memory Profiler を使用すると、次のことができます。

TPU メモリを特定のオペレーションに割り当てる方法は指定できません。JAX 固有の TPU パフォーマンスに関する問題の詳細については、JAX で TPU を使用する場合のパフォーマンスに関するメモをご覧ください。

TPU の問題のトラブルシューティング

TPU が実行されていることを確認する方法

詳細

JAX から「No GPU/TPU found, falling back to CPU.」と出力されない限り、すべてが TPU で実行されます。

TPU がアクティブであることを確認するには、jax.devices() で複数の TPU デバイスが表示されていることを確認するか、assert jax.devices()[0].platform == 'tpu' を使用してプログラムで確認します。

RuntimeError: バックエンド「tpu」を初期化できませんでした: 利用不可: 使用可能な TPU プラットフォームはありません。

詳細

このランタイム エラー メッセージや、TPU VM の /tmp/tpu_logs/tpu_driver.WARNINGW1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx が見つかった場合、間違った TPU VM バージョンを実行している可能性があります。

現在の JAX ランタイム バージョンを実行していることを確認し、再試行します。