Solução de problemas do JAX: TPU

Neste guia, fornecemos orientações para solução de problemas do JAX que ajudam você a identificar e resolver problemas que podem ocorrer ao treinar modelos JAX no Cloud TPU.

Para um guia mais geral sobre como começar a usar o Cloud TPU, consulte o Guia de início rápido do JAX.

Problemas gerais do JAX

Se você tiver problemas durante o desenvolvimento do seu modelo de treinamento ou do treinamento com o JAX, consulte as Perguntas frequentes sobre o JAX (em inglês).

Para erros de programação mais gerais que podem ser encontrados ao escrever um aplicativo de treinamento com o JAX, consulte Erros do JAX (em inglês).

Como criar perfis de desempenho do JAX

É possível entender como os recursos de TPU estão sendo usados com as ferramentas descritas na seção Como criar perfis de desempenho do JAX (em inglês).

Solução de problemas de memória

É possível usar o JAX Device Memory Profiler (em inglês) para ver como a memória da TPU está sendo usada. Ele pode ser usado para:

Como solucionar problemas de TPU

Como posso verificar se a TPU está em execução?

Detalhes

Tudo será executado na TPU, desde que o JAX não exiba "No GPU/TPU found, falling back to CPU" (Nenhuma GPU/TPU encontrada, recorrendo à CPU).

Para verificar se a TPU está ativa, consulte jax.devices(), em que serão exibidos vários dispositivos TPU. Também é possível verificar programaticamente com assert jax.devices()[0].platform == 'tpu'.