Solução de problemas do JAX: TPU

Neste guia, fornecemos orientações para solução de problemas do JAX que ajudam você a identificar e resolver problemas que podem ocorrer ao treinar modelos JAX no Cloud TPU.

Para um guia mais geral sobre como começar a usar o Cloud TPU, consulte o Guia de início rápido do JAX.

Problemas gerais do JAX

Se você tiver problemas durante o desenvolvimento do seu modelo de treinamento ou do treinamento com o JAX, consulte as Perguntas frequentes sobre o JAX (em inglês).

Para erros de programação mais gerais que podem ser encontrados ao escrever um aplicativo de treinamento com o JAX, consulte Erros do JAX (em inglês).

Como criar perfis de desempenho do JAX

É possível entender como os recursos de TPU estão sendo usados com as ferramentas descritas na seção Como criar perfis de desempenho do JAX (em inglês).

Solução de problemas de memória

Você pode monitorar como a memória é usada com o JAX Device Memory Profiler, mas não pode gerenciar diretamente como ela é usada.

O Memory Profiler do dispositivo pode ser usado para:

Não é possível especificar como a memória da TPU é alocada para operações específicas. Para mais informações sobre problemas de desempenho de TPU específicos do JAX, consulte Notas de desempenho para o uso de TPUs com JAX.

Como solucionar problemas de TPU

Como posso verificar se a TPU está em execução?

Detalhes

Tudo será executado na TPU, desde que o JAX não exiba "No GPU/TPU found, falling back to CPU" (Nenhuma GPU/TPU encontrada, recorrendo à CPU).

Para verificar se a TPU está ativa, consulte jax.devices(), em que serão exibidos vários dispositivos TPU. Também é possível verificar programaticamente com assert jax.devices()[0].platform == 'tpu'.

RuntimeError: não foi possível inicializar o back-end 'tpu': UNAVAILABLE: nenhuma plataforma de TPU disponível.

Detalhes

Essa mensagem de erro no ambiente de execução e/ou a descoberta do seguinte em /tmp/tpu_logs/tpu_driver.WARNING na VM da TPU: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx podem indicar que você está executando a versão incorreta da VM da TPU.

Verifique se você está executando a versão atual do ambiente de execução do JAX e tente de novo.