Troubleshooting JAX - TPU
This guide provides pointers to JAX troubleshooting information to help you identify and resolve problems you might encounter while training JAX models on Cloud TPU.
For a more general guide to getting started with Cloud TPU, see the JAX quickstart.
General JAX issues
If you run into issues while developing your training model or training with JAX, see the JAX FAQ.
For more general programming errors you might encounter when writing a training application with JAX, see JAX Errors.
Profiling JAX performance
To understand how your TPU resources are being utilized, use the tools described in Profiling JAX performance.
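As a minimal sketch, the snippet below captures a trace of a jitted computation with jax.profiler.trace, which you can then open in TensorBoard's profile plugin. The step function and its matrix workload are hypothetical stand-ins for a real training step, and the log directory is whatever path you pass in.

```python
import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for a training step: a matmul plus an elementwise op.
    return jnp.tanh(x @ x.T).sum()


def capture_trace(log_dir: str) -> float:
    x = jnp.ones((512, 512))
    # Run once outside the trace so compilation time does not dominate the profile.
    step(x).block_until_ready()
    with jax.profiler.trace(log_dir):
        result = step(x)
        # Wait for the device work to finish before the trace is stopped.
        result.block_until_ready()
    return float(result)
```

For example, capture_trace("/tmp/jax-trace") writes the trace files under that directory. Warming up before tracing keeps the profile focused on steady-state execution rather than on XLA compilation.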
Troubleshooting memory issues
The JAX Device Memory Profiler shows how TPU memory is being used. You can use it to:
- Figure out which arrays and executables are in TPU memory at a given time, or
- Track down memory leaks.
Troubleshooting TPU issues
How can I verify that the TPU is running?
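JAX runs your computations on the TPU unless it prints the warning "No GPU/TPU found, falling back to CPU."
You can verify that the TPU is active in either of two ways: call jax.devices(), which should list several TPU devices, or check that the following assertion passes (jax.devices() returns a list, so the platform is read from one of its devices):
assert jax.devices()[0].platform == 'tpu'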
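If you want a check that reports what JAX actually sees instead of failing on the first device, a small sketch like the following may help. The helper names tpu_is_active and describe_devices are hypothetical; they rely only on jax.default_backend(), jax.devices(), jax.device_count(), and the platform, id, and device_kind attributes of each device.

```python
import jax


def tpu_is_active() -> bool:
    # True only when the default backend is the TPU and every visible device is a TPU.
    return jax.default_backend() == "tpu" and all(
        d.platform == "tpu" for d in jax.devices()
    )


def describe_devices() -> str:
    # One summary line, then one line per device visible to the default backend.
    lines = [f"backend={jax.default_backend()} device_count={jax.device_count()}"]
    lines += [f"  {d.id}: {d.platform} ({d.device_kind})" for d in jax.devices()]
    return "\n".join(lines)
```

Printing describe_devices() at the start of a training script makes a silent fallback to CPU visible in the logs, which is easier to spot than the absence of a warning.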