Risoluzione dei problemi relativi a JAX - TPU
Questa guida fornisce indicazioni per la risoluzione dei problemi di JAX per aiutarti a identificare e risolvere problemi che potresti riscontrare durante l'addestramento di modelli JAX su Cloud TPU.
Per una guida più generale su come iniziare a utilizzare Cloud TPU, consulta la guida rapida di JAX.
Problemi generici con JAX
Se riscontri problemi durante lo sviluppo del tuo modello di addestramento o l'addestramento con JAX, consulta le Domande frequenti su JAX.
Per errori di programmazione più generali che potresti riscontrare durante la scrittura di un'applicazione di addestramento con JAX, consulta la sezione Errori JAX.
Profilazione delle prestazioni di JAX
Puoi comprendere come vengono utilizzate le risorse TPU utilizzando gli strumenti descritti in Profilazione delle prestazioni JAX.
Risoluzione dei problemi di memoria
Puoi monitorare la modalità di utilizzo della memoria con il JAX Device Memory Profiler, ma non puoi gestirne direttamente l'utilizzo.
Il Profiler memoria del dispositivo può essere utilizzato per:
- Capire quali array ed eseguibili sono presenti nella memoria TPU in un determinato momento oppure
- Individua le fughe di memoria.
Non puoi specificare la modalità di allocazione della memoria TPU per operazioni specifiche. Per ulteriori informazioni sui problemi di prestazioni delle TPU specifici di JAX, consulta Note sulle prestazioni per l'utilizzo di TPU con JAX.
Risoluzione dei problemi relativi alla TPU
Come posso verificare che la TPU sia in esecuzione?
Dettagli
Tutto verrà eseguito sulla TPU purché JAX non mostri "Nessuna GPU/TPU trovata, passaggio alla CPU".
Puoi verificare che la TPU sia attiva osservando jax.devices()
, dove
dovresti visualizzare diversi dispositivi TPU, oppure puoi verificarla
in modo programmatico con: assert jax.devices()[0].platform == 'tpu'
.
RuntimeError: Impossibile inizializzare il backend 'tpu': UNAVAILABLE: Nessuna piattaforma TPU disponibile.
Dettagli
Questo messaggio di errore di runtime e/o quanto segue in /tmp/tpu_logs/tpu_driver.WARNING
sulla VM TPU:
W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
può indicare che stai eseguendo la versione della VM TPU errata.
Verifica di eseguire la versione corrente del runtime JAX e riprova.