JAX 문제 해결 - TPU
이 가이드에서는 Cloud TPU에서 JAX 모델을 학습시키는 동안 발생할 수 있는 문제를 식별하고 해결하는 데 도움이 되는 JAX 문제 해결 정보를 제공합니다.
Cloud TPU 시작에 대한 보다 일반적인 가이드는 JAX 빠른 시작을 참조하세요.
일반적인 JAX 문제
학습 모델을 개발하거나 JAX로 학습할 때 문제가 발생하면 JAX FAQ를 참조하세요.
JAX로 학습 애플리케이션을 작성할 때 발생할 수 있는 보다 일반적인 프로그래밍 오류는 JAX 오류를 참조하세요.
JAX 성능 프로파일링
JAX 성능 프로파일링에 설명된 도구를 사용하여 TPU 리소스가 사용되는 방식을 파악할 수 있습니다.
메모리 문제 해결
JAX 기기 메모리 프로파일러를 사용하여 메모리가 사용되는 방식을 모니터링할 수 있지만 메모리 사용 방식을 직접 관리할 수는 없습니다.
기기 메모리 프로파일러는 다음 작업에 사용할 수 있습니다.
- 특정 시점에 TPU 배열에 있는 배열 및 실행 파일 파악
- 메모리 누수 추적
특정 작업에 TPU 메모리가 할당되는 방식을 지정할 수 없습니다. JAX 특정 TPU 성능 문제에 대한 자세한 내용은 JAX에서 TPU 사용 시 성능 참고사항을 참조하세요.
TPU 문제 해결
TPU가 실행 중인지 확인하려면 어떻게 해야 하나요?
세부정보
JAX가 'GPU/TPU를 찾을 수 없으며 CPU로 돌아갑니다.'를 출력하지 않는 한 모든 항목이 TPU에서 실행됩니다.
여러 TPU 기기가 표시되는 경우 jax.devices()
를 확인하여 TPU가 활성 상태인지 확인하거나, assert jax.devices()[0].platform == 'tpu'
를 사용하여 프로그래매틱 방식으로 확인할 수 있습니다.
RuntimeError: 백엔드 'tpu'를 초기화할 수 없음: UNAVAILABLE: 사용 가능한 TPU 플랫폼이 없음
세부정보
이 런타임 오류 메시지 또는 TPU VM의 /tmp/tpu_logs/tpu_driver.WARNING
에서 W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
를 발견하면 잘못된 TPU VM 버전을 실행하고 있는 것일 수 있습니다.
현재 JAX 런타임 버전을 실행 중인지 확인하고 다시 시도합니다.