Melatih di TPU host tunggal menggunakan Pax


Dokumen ini memberikan pengantar singkat tentang cara menggunakan Pax di TPU host tunggal (v2-8, v3-8, v4-8).

Pax adalah framework untuk mengonfigurasi dan menjalankan eksperimen machine learning di atas JAX. Pax berfokus pada menyederhanakan ML dalam skala besar dengan berbagi komponen infrastruktur dengan framework ML yang ada dan memanfaatkan library pemodelan Praxis untuk modularitas.

Tujuan

  • Menyiapkan resource TPU untuk pelatihan
  • Menginstal Pax di TPU host tunggal
  • Melatih model SPMD berbasis transformer menggunakan Pax

Sebelum memulai

Jalankan perintah berikut untuk mengonfigurasi gcloud agar dapat menggunakan project Cloud TPU Anda dan menginstal komponen yang diperlukan untuk melatih model yang menjalankan Pax di TPU host tunggal.

Menginstal Google Cloud CLI.

Google Cloud CLI berisi alat dan library untuk berinteraksi dengan produk dan layanan Google Cloud CLI. Jika Anda belum menginstalnya sebelumnya, instal sekarang menggunakan petunjuk di Menginstal Google Cloud CLI.

Mengonfigurasi perintah gcloud

(Jalankan gcloud auth list untuk melihat akun yang tersedia).

$ gcloud config set account account

$ gcloud config set project project-id

Mengaktifkan Cloud TPU API

Aktifkan Cloud TPU API menggunakan perintah gcloud berikut di Cloud Shell. (Anda juga dapat mengaktifkannya dari Konsol Google Cloud).

$ gcloud services enable tpu.googleapis.com

Jalankan perintah berikut untuk membuat identitas layanan (akun layanan).

$ gcloud beta services identity create --service tpu.googleapis.com

Membuat VM TPU

Dengan VM Cloud TPU, model dan kode Anda berjalan langsung di VM TPU. Anda dapat menggunakan SSH langsung ke VM TPU. Anda dapat menjalankan kode arbitrer, menginstal paket, melihat log, dan men-debug kode langsung di VM TPU.

Buat VM TPU dengan menjalankan perintah berikut dari Cloud Shell atau terminal komputer tempat Google Cloud CLI diinstal.

Tetapkan zone berdasarkan ketersediaan dalam kontrak Anda, lihat Region dan Zona TPU jika diperlukan.

Tetapkan variabel accelerator-type ke v2-8, v3-8, atau v4-8.

Tetapkan variabel version ke tpu-vm-base untuk versi TPU v2 dan v3 atau tpu-vm-v4-base untuk TPU v4.

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version

Menghubungkan ke VM Google Cloud TPU

Lakukan SSH ke VM TPU menggunakan perintah berikut:

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

Saat Anda login ke VM, perintah shell akan berubah dari username@projectname menjadi username@vm-name:

Menginstal Pax di VM Google Cloud TPU

Instal Pax, JAX, dan libtpu di VM TPU menggunakan perintah berikut:

(vm)$ python3 -m pip install -U pip \
python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Pemeriksaan sistem

Uji apakah semuanya telah diinstal dengan benar dengan memeriksa apakah JAX melihat core TPU:

(vm)$ python3 -c "import jax; print(jax.device_count())"

Jumlah core TPU ditampilkan, yang seharusnya 8 jika Anda menggunakan v2-8 atau v3-8, atau 4 jika Anda menggunakan v4-8.

Menjalankan kode Pax di VM TPU

Sekarang Anda dapat menjalankan kode Pax yang diinginkan. Contoh lm_cloud adalah tempat yang tepat untuk mulai menjalankan model di Pax. Misalnya, perintah berikut melatih model bahasa SPMD berbasis transformer parameter 2B pada data sintetis.

Perintah berikut menampilkan output pelatihan untuk model bahasa SPMD. Proses ini melatih 300 langkah dalam waktu sekitar 20 menit.

(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

Pada slice v4-8, output harus menyertakan:

Kerugian dan waktu langkah

tensor ringkasan pada langkah=step_# loss = loss
tensor ringkasan pada langkah=step_# Langkah per detik x

Pembersihan

Agar tidak perlu membayar biaya pada akun Google Cloud Anda untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource.

Setelah selesai menggunakan VM TPU, ikuti langkah-langkah berikut untuk membersihkan resource.

Putuskan koneksi dari instance Compute Engine, jika Anda belum melakukannya:

(vm)$ exit

Hapus Cloud TPU Anda.

$ gcloud compute tpus tpu-vm delete tpu-name  --zone zone

Langkah selanjutnya

Untuk mengetahui informasi selengkapnya tentang Cloud TPU, lihat: