Soluciona problemas de JAX: TPU

En esta guía, se proporcionan indicadores de información para solucionar problemas de JAX que te ayudarán a identificar y resolver problemas que podrías encontrar mientras entrenas modelos de JAX en Cloud TPU.

Para obtener una guía más general a fin de comenzar a usar Cloud TPU, consulta la guía de inicio rápido de JAX.

Problemas generales de JAX

Si tienes problemas mientras desarrollas tu modelo de entrenamiento o entrenas con JAX, consulta las Preguntas frecuentes de JAX.

Para obtener información sobre los errores de programación más generales que puedes encontrar cuando escribes una aplicación de entrenamiento con JAX, consulta Errores de JAX.

Cómo generar perfiles del rendimiento de JAX

Puedes comprender cómo se utilizan tus recursos de TPU con las herramientas que se describen en Cómo generar perfiles del rendimiento de JAX.

Soluciona problemas de memoria

Puedes supervisar cómo se usa la memoria con el profilador de memoria del dispositivo JAX, pero no puedes administrar directamente cómo se usa.

El Generador de perfiles de memoria del dispositivo se puede usar para lo siguiente:

No puedes especificar cómo se asigna la memoria de la TPU para operaciones específicas. Para obtener más información sobre los problemas de rendimiento de las TPU específicos de JAX, consulta Notas de rendimiento para usar TPU con JAX.

Soluciona problemas de TPU

¿Cómo puedo verificar que la TPU se esté ejecutando?

Detalles

Todo se ejecutará en la TPU, siempre y cuando JAX no imprima “No se encontró ninguna GPU/TPU, se usará la CPU”.

Para verificar que la TPU esté activa, puedes consultar jax.devices(), donde deberías ver varios dispositivos TPU, o bien verificarlo de forma programática con assert jax.devices()[0].platform == 'tpu'.

RuntimeError: No se pudo inicializar el backend "tpu": NO DISPONIBLE: No hay una plataforma de TPU disponible.

Detalles

Este mensaje de error de tiempo de ejecución o encontrar lo siguiente en /tmp/tpu_logs/tpu_driver.WARNING en la VM de TPU: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx puede indicar que estás ejecutando la versión incorrecta de la VM de TPU.

Verifica que estés ejecutando la versión actual del entorno de ejecución de JAX y vuelve a intentarlo.