JAX - TPU のトラブルシューティング

このガイドでは、Cloud TPU で JAX モデルをトレーニングする際に発生する可能性のある問題を特定して解決するため、JAX のトラブルシューティング情報について説明します。

Cloud TPU を使い始める際の一般的なガイドについては、JAX クイックスタートをご覧ください。

JAX に関する一般的な問題

トレーニング モデルの開発中や JAX でのトレーニング中に問題が発生した場合は、JAX のよくある質問をご覧ください。

JAX を使用してトレーニング アプリケーションを作成するときに発生する可能性のある一般的なプログラミング エラーについては、JAX エラーをご覧ください。

JAX パフォーマンスをプロファイリングする

JAX パフォーマンスのプロファイリングで説明されているツールを使用して、TPU リソースの使用状況を把握できます。

メモリの問題のトラブルシューティング

JAX Device Memory Profiler でメモリの使用状況をモニタリングできますが、その使用状況について直接管理することはできません。

Device Memory Profiler は、次の目的で使用できます。

TPU メモリを特定のオペレーションに割り当てる方法は指定できません。JAX 固有の TPU パフォーマンスに関する問題の詳細については、JAX で TPU を使用する場合のパフォーマンスに関するメモをご覧ください。

TPU の問題のトラブルシューティング

TPU が実行されていることを確認する方法

詳細

JAX から「No GPU/TPU found, falling back to CPU.」と出力されない限り、すべてが TPU で実行されます。

TPU がアクティブであることを確認するには、jax.devices() で複数の TPU デバイスが表示されていることを確認するか、assert jax.devices()[0].platform == 'tpu' を使用してプログラムで確認します。

RuntimeError: バックエンド 'tpu' を初期化できません: UNAVAILABLE: 使用できない TPU プラットフォームがあります。

詳細

このランタイム エラー メッセージまたは TPU VM 上の /tmp/tpu_logs/tpu_driver.WARNING で次を確認する: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx は、誤った TPU VM バージョンを実行していることを示しています。

現在の JAX ランタイム バージョンを実行していることを確認し、再試行してください。