排查 JAX 问题 - TPU

本指南提供了 JAX 问题排查信息,可帮助您识别和解决在 Cloud TPU 上训练 JAX 模型时可能遇到的问题。

如需查看更通用的 Cloud TPU 使用入门指南,请参阅 JAX 快速入门

常规 JAX 问题

如果您在开发训练模型或使用 JAX 进行训练时遇到问题,请参阅 JAX 常见问题解答

如需了解使用 JAX 编写训练应用时可能遇到的更多常规编程错误,请参阅 JAX 错误

分析 JAX 性能

您可以使用剖析 JAX 性能中所述的工具,了解 TPU 资源的使用方式。

排查内存问题

您可以使用 JAX 设备内存分析器监控内存的使用情况,但无法直接管理其使用方式。

设备内存分析器可用于:

您无法指定如何为特定操作分配 TPU 内存。如需详细了解特定于 JAX 的 TPU 性能问题,请参阅将 TPU 与 JAX 搭配使用的性能说明

排查 TPU 问题

如何验证 TPU 是否正在运行?

详情

只要 JAX 不输出“找不到 GPU/TPU,回退到 CPU”,一切都会在 TPU 上运行。

您可以通过查看 jax.devices()(其中应该会显示多个 TPU 设备)来验证 TPU 处于活跃状态,或者使用 assert jax.devices()[0].platform == 'tpu' 以编程方式进行验证。

RuntimeError:无法初始化后端“tpu”:UNAVAILABLE:没有可用的 TPU 平台。

详细信息

此运行时错误消息和/或在 TPU 虚拟机的 /tmp/tpu_logs/tpu_driver.WARNING 中找到以下内容:W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx 可能表示您运行的 TPU 虚拟机版本有误。

请验证您运行的是否为当前 JAX 运行时版本,然后重试。