排查 JAX 问题 - TPU
本指南提供了 JAX 问题排查信息,可帮助您识别和解决在 Cloud TPU 上训练 JAX 模型时可能遇到的问题。
如需更广泛的 Cloud TPU 使用入门指南,请参阅 JAX 快速入门。
常见的 JAX 问题
如果您在开发训练模型或使用 JAX 进行训练时遇到问题,请参阅 JAX 常见问题解答。
如需了解使用 JAX 编写训练应用时可能遇到的更多常规编程错误,请参阅 JAX 错误。
分析 JAX 性能
您可以使用剖析 JAX 性能中所述的工具,了解 TPU 资源的使用方式。
排查内存问题
您可以使用 JAX 设备内存性能分析器监控内存的使用方式,但无法直接管理内存的使用方式。
设备内存性能分析器可用于:
- 找出给定时间 TPU 内存中有哪些数组和可执行文件,或者
- 跟踪内存泄漏。
您无法指定如何为特定操作分配 TPU 内存。如需详细了解 JAX 专用 TPU 性能问题,请参阅关于将 TPU 与 JAX 搭配使用时的性能注意事项。
排查 TPU 问题
如何验证 TPU 是否正在运行?
详情
只要 JAX 没有输出“No GPU/TPU found, falling back to CPU”(未找到 GPU/TPU,回退到 CPU),所有内容都将在 TPU 上运行。
您可以通过查看 jax.devices()
(其中应显示多个 TPU 设备)或使用以下程序化方式进行验证:assert jax.devices()[0].platform == 'tpu'
,来验证 TPU 是否处于活动状态。
RuntimeError: Unable to initialize backend 'tpu': UNAVAILABLE: No TPU Platform available.
详细信息
如果您看到以下运行时错误消息,并且/或者在 TPU 虚拟机上的 /tmp/tpu_logs/tpu_driver.WARNING
中找到以下内容:W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
,则可能表示您运行的 TPU 虚拟机版本不正确。
请验证您是否运行的是当前的 JAX 运行时版本,然后重试。