Risoluzione dei problemi relativi a JAX - TPU

Questa guida fornisce indicazioni per la risoluzione dei problemi di JAX per aiutarti a identificare e risolvere problemi che potresti riscontrare durante l'addestramento di modelli JAX su Cloud TPU.

Per una guida più generale su come iniziare a utilizzare Cloud TPU, consulta la guida rapida di JAX.

Problemi generici con JAX

Se riscontri problemi durante lo sviluppo del tuo modello di addestramento o l'addestramento con JAX, consulta le Domande frequenti su JAX.

Per errori di programmazione più generali che potresti riscontrare durante la scrittura di un'applicazione di addestramento con JAX, consulta la sezione Errori JAX.

Profilazione delle prestazioni di JAX

Puoi comprendere come vengono utilizzate le risorse TPU utilizzando gli strumenti descritti in Profilazione delle prestazioni JAX.

Risoluzione dei problemi di memoria

Puoi monitorare la modalità di utilizzo della memoria con il JAX Device Memory Profiler, ma non puoi gestirne direttamente l'utilizzo.

Il Profiler memoria del dispositivo può essere utilizzato per:

Non puoi specificare la modalità di allocazione della memoria TPU per operazioni specifiche. Per ulteriori informazioni sui problemi di prestazioni delle TPU specifici di JAX, consulta Note sulle prestazioni per l'utilizzo di TPU con JAX.

Risoluzione dei problemi relativi alla TPU

Come posso verificare che la TPU sia in esecuzione?

Dettagli

Tutto verrà eseguito sulla TPU purché JAX non mostri "Nessuna GPU/TPU trovata, passaggio alla CPU".

Puoi verificare che la TPU sia attiva osservando jax.devices(), dove dovresti visualizzare diversi dispositivi TPU, oppure puoi verificarla in modo programmatico con: assert jax.devices()[0].platform == 'tpu'.

RuntimeError: Impossibile inizializzare il backend 'tpu': UNAVAILABLE: Nessuna piattaforma TPU disponibile.

Dettagli

Questo messaggio di errore di runtime e/o quanto segue in /tmp/tpu_logs/tpu_driver.WARNING sulla VM TPU: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx può indicare che stai eseguendo la versione della VM TPU errata.

Verifica di eseguire la versione corrente del runtime JAX e riprova.