Solução de problemas do JAX: TPU
Neste guia, fornecemos orientações para solução de problemas do JAX que ajudam você a identificar e resolver problemas que podem ocorrer ao treinar modelos JAX no Cloud TPU.
Para um guia mais geral sobre como começar a usar o Cloud TPU, consulte o Guia de início rápido do JAX.
Problemas gerais do JAX
Se você tiver problemas durante o desenvolvimento do seu modelo de treinamento ou do treinamento com o JAX, consulte as Perguntas frequentes sobre o JAX (em inglês).
Para erros de programação mais gerais que podem ser encontrados ao escrever um aplicativo de treinamento com o JAX, consulte Erros do JAX (em inglês).
Como criar perfis de desempenho do JAX
É possível entender como os recursos de TPU estão sendo usados com as ferramentas descritas na seção Como criar perfis de desempenho do JAX (em inglês).
Solução de problemas de memória
Você pode monitorar como a memória é usada com o JAX Device Memory Profiler, mas não pode gerenciar diretamente como ela é usada.
O Memory Profiler do dispositivo pode ser usado para:
- descobrir quais matrizes e executáveis estão na memória da TPU (em inglês) em um determinado momento; ou
- rastrear vazamentos de memória (em inglês).
Não é possível especificar como a memória da TPU é alocada para operações específicas. Para mais informações sobre problemas de desempenho de TPU específicos do JAX, consulte Notas de desempenho para o uso de TPUs com JAX.
Como solucionar problemas de TPU
Como posso verificar se a TPU está em execução?
Detalhes
Tudo será executado na TPU, desde que o JAX não exiba "No GPU/TPU found, falling back to CPU" (Nenhuma GPU/TPU encontrada, recorrendo à CPU).
Para verificar se a TPU está ativa, consulte jax.devices()
, em que
serão exibidos vários dispositivos TPU. Também é possível verificar
programaticamente com assert jax.devices()[0].platform == 'tpu'
.
RuntimeError: não foi possível inicializar o back-end 'tpu': UNAVAILABLE: nenhuma plataforma de TPU disponível.
Detalhes
Essa mensagem de erro no ambiente de execução e/ou a descoberta do seguinte em /tmp/tpu_logs/tpu_driver.WARNING
na VM da TPU: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
podem indicar que você está executando a versão incorreta da VM da TPU.
Verifique se você está executando a versão atual do ambiente de execução do JAX e tente de novo.