排解 JAX - TPU 問題

本指南提供 JAX 疑難排解資訊的指標,協助您在 Cloud TPU 上訓練 JAX 模型時,找出並解決可能遇到的問題。

如需 Cloud TPU 的一般入門指南,請參閱 JAX 快速入門導覽課程

一般 JAX 問題

如果您在使用 JAX 開發訓練模型或訓練時遇到問題,請參閱 JAX 常見問題

如要進一步瞭解使用 JAX 編寫訓練應用程式時可能遇到的一般程式設計錯誤,請參閱「JAX 錯誤」。

剖析 JAX 效能

您可以使用「剖析 JAX 效能」一文所述的工具,瞭解 TPU 資源的使用情形。

排解記憶體問題

您可以使用 JAX Device Memory Profiler 監控記憶體的使用方式,但無法直接管理記憶體的使用方式。

裝置記憶體分析器可用於:

您無法指定 TPU 記憶體如何分配給特定作業。如要進一步瞭解 JAX 專屬 TPU 效能問題,請參閱「使用 JAX 搭配 TPU 的效能注意事項」。

排解 TPU 問題

如何確認 TPU 是否運作?

詳細資料

只要 JAX 未顯示「No GPU/TPU found, falling back to CPU.」(找不到 GPU/TPU,改用 CPU),所有內容都會在 TPU 上執行。

您可以查看 jax.devices() 來驗證 TPU 是否處於活動狀態,您應該會看到幾個 TPU 裝置顯示在其中。您也可以透過程式輔助方式驗證,方法是使用 assert jax.devices()[0].platform == 'tpu'

RuntimeError:無法初始化後端「tpu」:UNAVAILABLE:沒有可用的 TPU 平台。

詳細資料

這個執行階段錯誤訊息和/或在 TPU VM 的 /tmp/tpu_logs/tpu_driver.WARNING 中找到以下內容: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx 可能表示您執行的 TPU VM 版本不正確。

請確認您執行的是目前的 JAX 執行階段版本,然後再試一次。