Titik Pemeriksaan Otomatis Cloud TPU [Pratinjau Publik]
Ringkasan
Secara historis, saat VM TPU memerlukan pemeliharaan, prosedur akan segera dimulai, tanpa memberi waktu bagi pengguna untuk melakukan tindakan yang mempertahankan progres seperti menyimpan titik pemeriksaan. Hal ini ditunjukkan dalam Gambar 1(a).
Gambar 1. Ilustrasi fitur Autocheckpoint: (a) Tanpa Autocheckpoint, progres pelatihan dari checkpoint terakhir akan hilang saat ada peristiwa pemeliharaan mendatang. (b) Dengan Autocheckpoint, progres pelatihan sejak checkpoint terakhir dapat dipertahankan saat ada peristiwa pemeliharaan mendatang.
Anda dapat menggunakan Autocheckpoint (Gambar 1(b)) untuk mempertahankan progres pelatihan dengan mengonfigurasi kode untuk menyimpan titik pemeriksaan yang tidak terjadwal saat peristiwa pemeliharaan terjadi. Saat peristiwa pemeliharaan terjadi, progres sejak titik kontrol terakhir akan otomatis disimpan. Fitur ini berfungsi pada satu irisan dan Multislice.
Fitur Autocheckpoint berfungsi dengan framework yang dapat menangkap SIGTERM, lalu menyimpan checkpoint. Framework yang didukung mencakup MaxText, Pax, dan JAX dengan Orbax. Dukungan untuk framework tambahan akan diumumkan saat tersedia.
Untuk saat ini, hanya TPU (v2-v4, dan v5e) yang dibuat melalui Cloud TPU API yang dapat menggunakan fitur ini. Dukungan untuk TPU di GKE akan diumumkan saat tersedia.
Menggunakan Titik pemeriksaan otomatis
Fungsi titik pemeriksaan otomatis dinonaktifkan secara default. Saat membuat
TPU atau resource dalam antrean,
Anda dapat mengaktifkannya dengan menambahkan flag --autocheckpoint-enabled
saat menyediakan
TPU.
Dengan mengaktifkan fitur ini, Cloud TPU
akan melakukan langkah-langkah berikut setelah menerima notifikasi tentang
peristiwa pemeliharaan:
- Ambil SIGTERM yang dikirim ke proses menggunakan perangkat TPU,
- Menunggu hingga proses keluar, atau 5 menit telah berlalu, mana saja yang lebih dulu, dan melakukan pemeliharaan pada slice yang terpengaruh.
Perhatikan bahwa infrastruktur yang digunakan oleh Autocheckpoint tidak bergantung pada framework ML. Framework ML apa pun dapat mendukung Autocheckpoint asalkan dapat menangkap sinyal SIGTERM dan memulai proses pembuatan checkpoint.
Dalam kode aplikasi, Anda perlu mengaktifkan kemampuan Autocheckpoint yang disediakan oleh framework ML. Misalnya, di Pax, hal ini berarti mengaktifkan flag command line saat meluncurkan pelatihan (lihat Panduan Memulai autocheckpoint dengan Pax). Di balik layar, framework menyimpan titik pemeriksaan yang tidak terjadwal saat SIGTERM diterima dan VM TPU yang terpengaruh akan menjalani pemeliharaan saat TPU tidak lagi digunakan.
Panduan memulai: Titik henti sementara otomatis dengan MaxText
MaxText adalah "LLM berperforma tinggi, skalabilitas arbitrer, open source, dan teruji dengan baik yang ditulis dalam Python/JAX murni yang menargetkan Cloud TPU". MaxText berisi semua penyiapan yang diperlukan untuk menggunakan fitur Autocheckpoint.
README MaxText menjelaskan dua cara untuk menjalankan MaxText dalam skala besar:
- Menggunakan
multihost_runner.py
, direkomendasikan untuk eksperimen - Menggunakan
multihost_job.job
, direkomendasikan untuk produksi
Saat menggunakan multihost_runner.py
, satu-satunya perubahan yang diperlukan
adalah menetapkan flag autocheckpoint-enabled
saat menyediakan
resource yang diantrekan. Saat menggunakan
multihost_job.py
, satu-satunya perubahan yang diperlukan adalah menentukan
flag command line ENABLE_AUTOCHECKPOINT=true
saat meluncurkan tugas.
Panduan memulai: Titik periksa otomatis dengan Pax pada satu slice
Di bagian ini, kami memberikan contoh cara menyiapkan dan menggunakan Autocheckpoint dengan Pax di satu slice. Dengan penyiapan yang sesuai:
- Titik pemeriksaan akan disimpan saat peristiwa pemeliharaan terjadi.
- Cloud TPU akan melakukan pemeliharaan pada VM TPU yang terpengaruh setelah checkpoint disimpan.
- Setelah Cloud TPU menyelesaikan pemeliharaan, Anda dapat menggunakan VM TPU seperti biasa.
Gunakan flag
autocheckpoint-enabled
saat membuat VM TPU atau resource dalam antrean.Contoh:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Menginstal Pax di satu slice
Fitur Autocheckpoint berfungsi pada Pax versi >= 1.1.0. Di VM TPU, instal
jax[tpu]
danpaxml
terbaru:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Meluncurkan pelatihan dengan konfigurasi yang sesuai
Contoh berikut menunjukkan cara mengonfigurasi model
LmCloudSpmd2B
untuk menyimpan titik pemeriksaan yang dipicu oleh Autocheckpoint ke bucket Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Perhatikan dua flag yang diteruskan ke perintah:
jax_fully_async_checkpoint
: Jika flag ini diaktifkan,orbax.checkpoint.AsyncCheckpointer
akan digunakan. ClassAsyncCheckpointer
otomatis menyimpan titik pemeriksaan saat skrip pelatihan menerima sinyal SIGTERM.exit_after_ondemand_checkpoint
: Dengan mengaktifkan tanda ini, proses TPU akan keluar setelah Autocheckpoint berhasil disimpan, yang memicu pemeliharaan untuk segera dilakukan. Jika Anda tidak menggunakan flag ini, pelatihan akan dilanjutkan setelah titik pemeriksaan disimpan dan Cloud TPU akan menunggu waktu tunggu habis (5 menit) sebelum melakukan pemeliharaan yang diperlukan.
Panduan memulai: Titik pemeriksaan otomatis dengan Pax di Multislice
Titik periksa otomatis tidak hanya berfungsi untuk satu slice, tetapi juga untuk Multislice. Bagian ini menjelaskan langkah-langkah yang diperlukan untuk menggunakan Autocheckpoint dengan Multislice.
Tentukan Autocheckpoint selama pembuatan resource yang diantrekan.
Lingkungan Multislice hanya dapat disediakan melalui permintaan resource yang diantrekan. Serupa dengan kasus satu slice, gunakan flag
autocheckpoint-enabled
dalam panggilan untuk membuat resource yang diantrekan.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Lihat Panduan Pengguna Multislice untuk mengetahui detail tentang semua opsi yang tersedia. Setelah permintaan resource antrean dibuat dan dalam status
ACTIVE
, ikuti langkah berikutnya untuk menjalankan Pax dengan Autocheckpoint.Instal Pax di semua VM di lingkungan Multislice.
Di VM TPU, instal
jax[tpu]
danpaxml
terbaru di semua VM TPU di lingkungan Multislice Anda:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Meluncurkan pelatihan dengan konfigurasi yang sesuai
Contoh ini menunjukkan cara mengonfigurasi model
LmCloudSpmd2B
untuk Autocheckpoint saat berlatih di lingkungan Multislice. Sebelum menjalankan skrip pelatihan, tetapkan DCN_MESH_SHAPE ke [2, 1, 1] seperti yang ditunjukkan dalam kode berikut:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Saat meluncurkan pelatihan, selain flag command line yang dibahas dalam kasus satu slice, tiga flag lainnya diperlukan:
num_hosts
: jumlah total host. Dalam hal ini, nilainya adalah 2.host_index
: indeks host yang meluncurkan pelatihan. Nilai ini bervariasi dari 0 hinggaN-1
denganN
adalah jumlah total host.server_addr
: alamat IP pekerja 0 dari node 0, dengan port yang tidak digunakan (misalnya, 8476). Untuk menemukan informasi ini, gunakanhostname -i
pada pekerja 0 dari node 0.
Titik pemeriksaan otomatis dengan Orbax
Fitur Autocheckpoint tidak terbatas pada MaxText atau Pax. Setiap framework yang dapat menangkap sinyal SIGTERM dan memulai proses checkpointing berfungsi dengan infrastruktur yang disediakan oleh Autocheckpoint. Orbax, namespace yang menyediakan library utilitas umum untuk pengguna JAX, menyediakan kemampuan ini.
Seperti yang dijelaskan dalam dokumentasi Orbax,
kemampuan ini diaktifkan secara default untuk pengguna
orbax.checkpoint.CheckpointManager
. Metode save
yang dipanggil setelah setiap langkah akan otomatis memeriksa apakah peristiwa
pemeliharaan akan segera terjadi, dan jika ya, akan menyimpan titik pemeriksaan meskipun nomor langkah
bukan kelipatan save_interval_steps
.
Dokumentasi GitHub
juga mengilustrasikan cara membuat pelatihan keluar setelah menyimpan
Autocheckpoint, dengan modifikasi pada kode pengguna.