Solución de problemas de JAX: TPU

En esta guía, se proporcionan indicaciones para la información de solución de problemas de JAX para ayudarte a identificar y resolver problemas que puedes encontrar mientras entrenas modelos de JAX en Cloud TPU.

Si deseas obtener una guía más general para comenzar a usar Cloud TPU, consulta la guía de inicio rápido de JAX.

Problemas generales de JAX

Si tienes problemas durante el desarrollo del modelo de entrenamiento o el entrenamiento con JAX, consulta las Preguntas frecuentes sobre JAX.

Para ver errores de programación más generales que puedes encontrar cuando escribes una aplicación de entrenamiento con JAX, consulta Errores de JAX.

Cómo generar perfiles de rendimiento de JAX

Puedes comprender cómo se usan los recursos TPU mediante las herramientas que se describen en Crea perfiles del rendimiento de JAX.

Soluciona problemas de memoria

Puedes supervisar la manera en que se usa la memoria con el Generador de perfiles de memoria del dispositivo JAX, pero no puedes administrar directamente la forma en que se usa.

El Generador de perfiles de memoria del dispositivo se puede usar para lo siguiente:

No puedes especificar cómo se asigna la memoria de TPU para operaciones específicas. Para obtener más información sobre problemas de rendimiento específicos de las TPU de JAX, consulta las Notas de rendimiento para usar TPU con JAX.

Soluciona problemas de TPU

¿Cómo puedo verificar que la TPU se esté ejecutando?

Detalles

Todo se ejecutará en la TPU, siempre que JAX no imprima "No se encontró ninguna GPU/TPU, se recurrirá a la CPU".

Puedes verificar que la TPU esté activa mirando jax.devices(), donde deberías ver varios dispositivos de TPU, o verifícalo de manera programática con assert jax.devices()[0].platform == 'tpu'.

RuntimeError: No se puede inicializar el backend "tpu": NO DISPONIBLE: No hay plataforma de TPU disponible.

Detalles

Este mensaje de error del entorno de ejecución o encontrar lo siguiente en /tmp/tpu_logs/tpu_driver.WARNING en la VM de TPU: W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx puede indicar que estás ejecutando la versión incorrecta de la VM de TPU.

Verifica que estés ejecutando la versión actual del entorno de ejecución de JAX y vuelve a intentarlo.