排查 JAX 问题 - TPU

本指南提供了 JAX 问题排查信息,可帮助您识别和解决在 Cloud TPU 上训练 JAX 模型时可能遇到的问题。

如需了解如何开始使用 Cloud TPU,请参阅 JAX 快速入门

常规 JAX 问题

如果您在开发训练模型或使用 JAX 进行训练时遇到问题,请参阅 JAX 常见问题解答

如需了解使用 JAX 编写训练应用时可能遇到的更多常规编程错误,请参阅 JAX 错误

剖析 JAX 性能

您可以使用剖析 JAX 性能中所述的工具,了解 TPU 资源的使用方式。

排查内存问题

您可以使用 JAX 设备内存分析器查看 TPU 内存的使用方式。它可用于:

排查 TPU 问题

如何验证 TPU 是否正在运行?

详情

只要 JAX 未输出“找不到 GPU/TPU,回退到 CPU”,TPU 就会运行。

如需验证 TPU 是否处于活跃状态,您可以查看 jax.devices()(应该看到多个 TPU 设备),或者使用 assert jax.devices()[0].platform == 'tpu' 以编程方式验证。