JAX 문제 해결 - TPU

이 가이드에서는 Cloud TPU에서 JAX 모델을 학습시키는 동안 발생할 수 있는 문제를 식별하고 해결하는 데 도움이 되는 JAX 문제 해결 정보를 제공합니다.

Cloud TPU 시작에 대한 보다 일반적인 가이드는 JAX 빠른 시작을 참조하세요.

일반적인 JAX 문제

학습 모델을 개발하거나 JAX로 학습할 때 문제가 발생하면 JAX FAQ를 참조하세요.

JAX로 학습 애플리케이션을 작성할 때 발생할 수 있는 보다 일반적인 프로그래밍 오류는 JAX 오류를 참조하세요.

JAX 성능 프로파일링

JAX 성능 프로파일링에 설명된 도구를 사용하여 TPU 리소스가 사용되는 방식을 파악할 수 있습니다.

메모리 문제 해결

JAX 기기 메모리 프로파일러를 사용하여 메모리가 사용되는 방식을 모니터링할 수 있지만 메모리 사용 방식을 직접 관리할 수는 없습니다.

기기 메모리 프로파일러를 사용하면 다음을 수행할 수 있습니다.

특정 작업에 TPU 메모리를 할당하는 방법을 지정할 수 없습니다. JAX 특정 TPU 성능 문제에 대한 자세한 내용은 JAX에서 TPU 사용 시 성능 참고사항을 참조하세요.

TPU 문제 해결

TPU가 실행 중인지 확인하려면 어떻게 해야 하나요?

세부정보

JAX가 'GPU/TPU를 찾을 수 없으며 CPU로 돌아갑니다.'를 출력하지 않는 한 모든 항목이 TPU에서 실행됩니다.

여러 TPU 기기가 표시되는 경우 jax.devices()를 확인하여 TPU가 활성 상태인지 확인하거나, assert jax.devices()[0].platform == 'tpu'를 사용하여 프로그래매틱 방식으로 확인할 수 있습니다.

RuntimeError: 백엔드 'tpu'를 초기화할 수 없습니다. UNAVAILABLE: TPU 플랫폼을 사용할 수 없습니다.

세부정보

이 런타임 오류 메시지가 표시되거나 TPU VM의 /tmp/tpu_logs/tpu_driver.WARNING에서 W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx이 잘못된 TPU VM 버전을 실행하고 있음을 나타낼 수 있습니다.

현재 JAX 런타임 버전을 실행하고 있는지 확인하고 다시 시도하세요.