排查 JAX 问题 - TPU

本指南提供了 JAX 问题排查信息,可帮助您识别和解决在 Cloud TPU 上训练 JAX 模型时可能遇到的问题。

如需查看如何开始使用 Cloud TPU 的通用指南,请参阅 JAX 快速入门

常见的 JAX 问题

如果您在开发训练模型或使用 JAX 进行训练时遇到问题,请参阅 JAX 常见问题解答

如需了解使用 JAX 编写训练应用时可能遇到的更多常规编程错误,请参阅 JAX 错误

对 JAX 性能进行性能分析

您可以使用剖析 JAX 性能中所述的工具,了解 TPU 资源的使用方式。

排查内存问题

您可以使用 JAX 设备内存分析器监控内存的使用情况,但无法直接管理内存的使用情况。

设备内存分析器可用于:

您无法指定如何为特定操作分配 TPU 内存。如需详细了解 JAX 特有的 TPU 性能问题,请参阅将 TPU 与 JAX 搭配使用的性能说明

排查 TPU 问题

如何验证 TPU 是否正在运行?

详情

只要 JAX 未输出“找不到 GPU/TPU,回退到 CPU”,所有一切都会在 TPU 上运行。

您可以通过查看 jax.devices() 来验证 TPU 是否处于活跃状态,您应该会看到显示了多个 TPU 设备,也可以使用 assert jax.devices()[0].platform == 'tpu' 以程序化方式进行验证。

RuntimeError:无法初始化后端“tpu”:UNAVAILABLE:没有可用的 TPU 平台。

详细信息

此运行时错误消息和/或在 TPU 虚拟机上的 /tmp/tpu_logs/tpu_driver.WARNING 中找到以下内容:W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx 可能表示您运行的 TPU 虚拟机版本不正确。

验证您是否运行的是当前 JAX 运行时版本,然后重试。