Solução de problemas do JAX: TPU

Neste guia, fornecemos orientações para solução de problemas do JAX que ajudam você a identificar e resolver problemas que podem ocorrer ao treinar modelos JAX no Cloud TPU.

Para um guia mais geral sobre como começar a usar o Cloud TPU, consulte o Guia de início rápido do JAX.

Problemas gerais do JAX

Se você tiver problemas durante o desenvolvimento do seu modelo de treinamento ou do treinamento com o JAX, consulte as Perguntas frequentes sobre o JAX (em inglês).

Para erros de programação mais gerais que podem ser encontrados ao escrever um aplicativo de treinamento com o JAX, consulte Erros do JAX (em inglês).

Como criar perfis de desempenho do JAX

É possível entender como os recursos de TPU estão sendo usados com as ferramentas descritas na seção Como criar perfis de desempenho do JAX (em inglês).

Solução de problemas de memória

É possível monitorar como a memória é usada com o JAX Device Memory Profiler, mas não é possível gerenciar diretamente como ela é usada.

O Device Memory Profiler pode ser usado para:

Não é possível especificar como a memória da TPU é alocada para operações específicas. Para mais informações sobre problemas de desempenho da TPU específicos do JAX, consulte Observações sobre desempenho para usar TPUs com o JAX.

Como solucionar problemas de TPU

Como posso verificar se a TPU está em execução?

Detalhes

Tudo será executado na TPU, desde que o JAX não exiba "No GPU/TPU found, falling back to CPU" (Nenhuma GPU/TPU encontrada, recorrendo à CPU).

Para verificar se a TPU está ativa, consulte jax.devices(), em que serão exibidos vários dispositivos TPU. Também é possível verificar programaticamente com assert jax.devices()[0].platform == 'tpu'.

RuntimeError: não foi possível inicializar o back-end "tpu": INDISPONÍVEL: nenhuma plataforma TPU disponível.

Detalhes

Esta mensagem de erro de execução e/ou encontrar o seguinte em /tmp/tpu_logs/tpu_driver.WARNING na VM de TPU: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx pode indicar que você está executando a versão errada da VM de TPU.

Verifique se você está executando a versão atual do ambiente de execução do JAX e tente novamente.