Memecahkan masalah JAX - TPU
Panduan ini memberikan petunjuk ke informasi pemecahan masalah JAX untuk membantu Anda mengidentifikasi dan menyelesaikan masalah yang mungkin Anda alami saat melatih model JAX di Cloud TPU.
Untuk panduan yang lebih umum tentang cara memulai Cloud TPU, lihat panduan memulai JAX.
Masalah JAX umum
Jika Anda mengalami masalah saat mengembangkan model pelatihan atau pelatihan dengan JAX, lihat FAQ JAX.
Untuk error pemrograman umum lainnya yang mungkin Anda temui saat menulis aplikasi pelatihan dengan JAX, lihat Error JAX.
Membuat profil performa JAX
Anda dapat memahami cara resource TPU digunakan menggunakan alat yang dijelaskan dalam Membuat profil performa JAX.
Memecahkan masalah memori
Anda dapat memantau penggunaan memori dengan JAX Device Memory Profiler, tetapi Anda tidak dapat langsung mengelola penggunaannya.
Memory Profiler Perangkat dapat digunakan untuk:
- Cari tahu array dan file yang dapat dieksekusi yang ada di memori TPU pada waktu tertentu, atau
- Melacak kebocoran memori.
Anda tidak dapat menentukan cara memori TPU dialokasikan untuk operasi tertentu. Untuk informasi selengkapnya tentang masalah performa TPU khusus JAX, lihat Catatan Performa untuk menggunakan TPU dengan JAX.
Memecahkan masalah TPU
Bagaimana cara memverifikasi bahwa TPU sedang berjalan?
Detail
Semuanya akan dijalankan di TPU selama JAX tidak mencetak "Tidak ada GPU/TPU yang ditemukan, kembali ke CPU".
Anda dapat memverifikasi bahwa TPU aktif dengan melihat jax.devices()
, tempat
Anda akan melihat beberapa perangkat TPU ditampilkan, atau memverifikasi
secara terprogram dengan: assert jax.devices()[0].platform == 'tpu'
.
RuntimeError: Tidak dapat melakukan inisialisasi backend 'tpu': TIDAK TERSEDIA: Tidak ada Platform TPU yang tersedia.
Detail
Pesan error runtime ini dan/atau menemukan hal berikut di /tmp/tpu_logs/tpu_driver.WARNING
pada VM TPU:
W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
dapat menunjukkan bahwa Anda menjalankan versi VM TPU yang salah.
Pastikan Anda menjalankan versi runtime JAX saat ini, lalu coba lagi.