Résoudre les problèmes liés à JAX – TPU

Ce guide fournit des pointeurs vers les informations de dépannage JAX pour vous aider à identifier et résoudre les problèmes que vous pouvez rencontrer lors de l'entraînement de modèles JAX sur Cloud TPU.

Pour obtenir des conseils plus généraux pour faire vos premiers pas avec Cloud TPU, consultez le guide de démarrage rapide de JAX.

Problèmes d'ordre général avec JAX

Si vous rencontrez des problèmes lors du développement de votre modèle d'entraînement ou de votre entraînement avec JAX, consultez les questions fréquentes sur JAX.

Pour connaître les erreurs de programmation plus générales que vous pouvez rencontrer lors de l'écriture d'une application d'entraînement avec JAX, consultez la section Erreurs JAX.

Profiler les performances JAX

Vous pouvez comprendre l'utilisation de vos ressources TPU à l'aide des outils décrits dans la section Profiler les performances JAX.

Résoudre les problèmes de mémoire

Vous pouvez surveiller l'utilisation de la mémoire avec le Profileur de mémoire de l'appareil JAX, mais vous ne pouvez pas gérer directement son utilisation.

Le Profileur de mémoire de l'appareil peut être utilisé pour:

Vous ne pouvez pas spécifier la manière dont la mémoire du TPU est allouée à des opérations spécifiques. Pour en savoir plus sur les problèmes de performances des TPU spécifiques à JAX, consultez les notes de performance pour l'utilisation de TPU avec JAX.

Résoudre les problèmes liés aux TPU

Comment vérifier que le TPU est en cours d'exécution ?

Détail

Tout sera exécuté sur le TPU tant que JAX n'affiche pas "Aucun GPU/TPU trouvé, retour au processeur".

Vous pouvez vérifier que le TPU est actif en consultant jax.devices(), qui affiche plusieurs appareils TPU, ou en effectuant une vérification automatisée avec : assert jax.devices()[0].platform == 'tpu'.

Erreur d'exécution: Impossible d'initialiser le backend "tpu": UNAVAILABLE: aucune plate-forme TPU n'est disponible.

Détails

Ce message d'erreur d'exécution et/ou le résultat suivant dans /tmp/tpu_logs/tpu_driver.WARNING sur la VM TPU : W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx peuvent indiquer que vous exécutez la mauvaise version de VM TPU.

Vérifiez que vous utilisez la version d'exécution JAX actuelle, puis réessayez.