Résoudre les problèmes liés à JAX – TPU

Ce guide fournit des pointeurs vers les informations de dépannage JAX pour vous aider à identifier et résoudre les problèmes que vous pouvez rencontrer lors de l'entraînement de modèles JAX sur Cloud TPU.

Pour obtenir des conseils plus généraux pour faire vos premiers pas avec Cloud TPU, consultez le guide de démarrage rapide de JAX.

Problèmes d'ordre général avec JAX

Si vous rencontrez des problèmes lors du développement de votre modèle d'entraînement ou de votre entraînement avec JAX, consultez les questions fréquentes sur JAX.

Pour connaître les erreurs de programmation plus générales que vous pouvez rencontrer lors de l'écriture d'une application d'entraînement avec JAX, consultez la section Erreurs JAX.

Profiler les performances JAX

Vous pouvez comprendre l'utilisation de vos ressources TPU à l'aide des outils décrits dans la section Profiler les performances JAX.

Résoudre les problèmes de mémoire

Vous pouvez utiliser le Profileur de mémoire de l'appareil JAX pour voir comment la mémoire TPU est utilisée. Il peut être utilisé pour :

Résoudre les problèmes liés aux TPU

Comment vérifier que le TPU est en cours d'exécution ?

Détail

Tout sera exécuté sur le TPU tant que JAX n'affiche pas "Aucun GPU/TPU trouvé, retour au processeur".

Vous pouvez vérifier que le TPU est actif en consultant jax.devices(), qui affiche plusieurs appareils TPU, ou en effectuant une vérification automatisée avec : assert jax.devices()[0].platform == 'tpu'.