Risoluzione dei problemi relativi a JAX - TPU
Questa guida fornisce indicazioni sulle informazioni per la risoluzione dei problemi di JAX per aiutarti a identificare e risolvere i problemi che potresti riscontrare durante l'addestramento dei modelli JAX su Cloud TPU.
Per una guida più generale su come iniziare a utilizzare Cloud TPU, consulta la guida rapida di JAX.
Problemi generali di JAX
Se riscontri problemi durante lo sviluppo del modello di addestramento o l'addestramento con JAX, consulta le domande frequenti su JAX.
Per errori di programmazione più generali che potresti riscontrare durante la scrittura di un'applicazione di addestramento con JAX, consulta Errori JAX.
Profilazione delle prestazioni di JAX
Puoi capire in che modo vengono utilizzate le risorse TPU utilizzando gli strumenti descritti in Profiling delle prestazioni di JAX.
Risoluzione dei problemi di memoria
Puoi monitorare l'utilizzo della memoria con JAX Device Memory Profiler, ma non puoi gestirne direttamente l'utilizzo.
Lo strumento di analisi della memoria del dispositivo può essere utilizzato per:
- Scopri quali array ed eseguibili sono nella memoria della TPU in un determinato momento oppure
- Rintraccia le perdite di memoria.
Non puoi specificare in che modo la memoria TPU viene allocata per operazioni specifiche. Per ulteriori informazioni sui problemi di prestazioni delle TPU specifici di JAX, consulta Note sul rendimento per l'utilizzo delle TPU con JAX.
Risolvere i problemi relativi alla TPU
Come faccio a verificare che la TPU sia in esecuzione?
Dettagli
Tutto verrà eseguito sulla TPU a condizione che JAX non stampi "Nessuna GPU/TPU trovata, ritorno alla CPU".
Puoi verificare che la TPU sia attiva controllando jax.devices()
, dove dovresti vedere diversi dispositivi TPU, oppure verificare tramite programmazione con: assert jax.devices()[0].platform == 'tpu'
.
RuntimeError: impossibile inizializzare il backend "tpu": NON DISPONIBILE: nessuna piattaforma TPU disponibile.
Dettagli
Questo messaggio di errore di runtime e/o il seguente messaggio in /tmp/tpu_logs/tpu_driver.WARNING
nella VM TPU:
W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
può indicare che stai utilizzando la versione errata della VM TPU.
Verifica di utilizzare la versione corrente del runtime JAX e riprova.