Autocheckpoint Cloud TPU [Pratinjau Publik]

Ringkasan

Secara historis, jika VM TPU memerlukan pemeliharaan, prosedur akan segera dimulai, tanpa memberi pengguna waktu untuk melakukan tindakan yang mempertahankan progres seperti menyimpan checkpoint. Hal ini ditunjukkan pada Gambar 1(a).

pos pemeriksaan otomatis

Gambar 1. Ilustrasi fitur Autocheckpoint: (a) Tanpa Autocheckpoint, progres pelatihan dari checkpoint terakhir akan hilang saat ada peristiwa pemeliharaan yang akan datang. (b) Dengan Autocheckpoint, progres pelatihan sejak checkpoint terakhir dapat dipertahankan ketika ada peristiwa pemeliharaan mendatang.

Anda dapat menggunakan Autocheckpoint (Gambar 1(b)) untuk mempertahankan progres pelatihan dengan mengonfigurasi kode untuk menyimpan checkpoint yang tidak terjadwal saat peristiwa pemeliharaan terjadi. Saat terjadi peristiwa pemeliharaan, progres sejak checkpoint terakhir akan disimpan secara otomatis. Fitur ini berfungsi pada irisan tunggal dan Multislice.

Fitur Autocheckpoint berfungsi dengan framework yang dapat menangkap SIGTERM, lalu menyimpan checkpoint. Framework yang didukung meliputi MaxText, Pax, dan JAX dengan Orbax. Dukungan untuk framework tambahan akan diumumkan begitu tersedia.

Untuk saat ini, hanya TPU (v2-v4, dan v5e) yang dibuat melalui Cloud TPU API yang dapat menggunakan fitur ini. Dukungan untuk TPU di GKE akan diumumkan saat tersedia.

Menggunakan Autocheckpoint

Fungsi titik pemeriksaan otomatis dinonaktifkan secara default. Saat membuat TPU atau resource yang diantrekan, Anda dapat mengaktifkannya dengan menambahkan flag --autocheckpoint-enabled saat menyediakan TPU. Setelah fitur tersebut diaktifkan, Cloud TPU akan melakukan langkah-langkah berikut setelah menerima notifikasi peristiwa pemeliharaan:

  1. Tangkap SIGTERM yang dikirim ke proses menggunakan perangkat TPU,
  2. Menunggu hingga proses keluar, atau 5 menit telah berlalu, mana saja yang terjadi lebih dulu, dan melakukan pemeliharaan pada slice yang terpengaruh.

Perlu diperhatikan bahwa infrastruktur yang digunakan oleh Autocheckpoint tidak bergantung pada framework ML. Semua framework ML dapat mendukung Autocheckpoint asalkan dapat menangkap sinyal SIGTERM dan memulai proses checkpoint.

Dalam kode aplikasi, Anda harus mengaktifkan kemampuan Autocheckpoint yang disediakan oleh framework ML. Misalnya, pada Pax, hal ini berarti mengaktifkan tanda command line saat meluncurkan pelatihan (lihat Panduan Memulai autocheckpoint dengan Pax). Di balik layar, framework menyimpan checkpoint yang tidak terjadwal saat SIGTERM diterima dan VM TPU yang terpengaruh menjalani pemeliharaan saat TPU tidak lagi digunakan.

Panduan memulai: Checkpoint otomatis dengan MaxText

MaxText adalah "LLM berperforma tinggi, skalabel secara arbitrer, open source, dan telah teruji dengan baik yang ditulis dalam Python/JAX murni yang menargetkan Cloud TPU". MaxText berisi semua penyiapan yang diperlukan untuk menggunakan fitur Autocheckpoint.

README MaxText menjelaskan dua cara untuk menjalankan MaxText dalam skala besar:

Saat menggunakan multihost_runner.py, satu-satunya perubahan yang diperlukan adalah menyetel flag autocheckpoint-enabled saat menyediakan resource yang diantrekan. Saat menggunakan multihost_job.py, satu-satunya perubahan yang diperlukan adalah menentukan flag command line ENABLE_AUTOCHECKPOINT=true saat meluncurkan tugas.

Panduan memulai: Autocheckpoint dengan Pax pada satu irisan

Di bagian ini, kami memberikan contoh cara menyiapkan dan menggunakan Autocheckpoint dengan Pax pada satu irisan. Dengan penyiapan yang sesuai:

  • Checkpoint akan disimpan saat terjadi peristiwa pemeliharaan.
  • Cloud TPU akan melakukan pemeliharaan pada VM TPU yang terpengaruh setelah checkpoint disimpan.
  • Saat Cloud TPU menyelesaikan pemeliharaan, Anda dapat menggunakan VM TPU seperti biasa.
  1. Gunakan flag autocheckpoint-enabled saat membuat VM TPU atau resource yang diantrekan.

    Contoh:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Menginstal Pax dalam satu irisan

    Fitur Autocheckpoint berfungsi pada versi Pax >= 1.1.0. Di VM TPU, instal jax[tpu] dan paxml terbaru:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Meluncurkan pelatihan dengan konfigurasi yang sesuai

    Contoh berikut menunjukkan cara mengonfigurasi model LmCloudSpmd2B untuk menyimpan checkpoint yang dipicu oleh Autocheckpoint ke bucket Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Perhatikan dua flag yang diteruskan ke perintah:

    • jax_fully_async_checkpoint: Dengan mengaktifkan tanda ini, orbax.checkpoint.AsyncCheckpointer akan digunakan. Class AsyncCheckpointer akan otomatis menyimpan checkpoint saat skrip pelatihan menerima sinyal SIGTERM.
    • exit_after_ondemand_checkpoint: Dengan mengaktifkan flag ini, proses TPU akan keluar setelah Autocheckpoint berhasil disimpan, yang memicu pemeliharaan yang akan segera dilakukan. Jika Anda tidak menggunakan flag ini, pelatihan akan dilanjutkan setelah checkpoint disimpan dan Cloud TPU akan menunggu hingga waktu tunggu (5 menit) sebelum melakukan pemeliharaan yang diperlukan.

Panduan memulai: Autocheckpoint dengan Pax di Multislice

Autocheckpoint tidak hanya berfungsi untuk slice tunggal, tetapi juga untuk Multislice. Bagian ini menjelaskan langkah-langkah yang diperlukan untuk menggunakan Autocheckpoint dengan Multislice.

  1. Menentukan Autocheckpoint selama pembuatan resource yang diantrekan.

    Lingkungan Multislice hanya dapat disediakan melalui permintaan resource yang diantrekan. Serupa dengan kasus irisan tunggal, gunakan tanda autocheckpoint-enabled dalam panggilan untuk membuat resource yang diantrekan.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Lihat Panduan Pengguna Multislice untuk mengetahui detail tentang semua opsi yang tersedia. Setelah permintaan resource yang diantrekan dibuat dan dalam status ACTIVE, ikuti langkah berikutnya untuk menjalankan Pax dengan Autocheckpoint.

  2. Menginstal Pax pada semua VM di lingkungan Multislice.

    Di VM TPU, instal jax[tpu] dan paxml terbaru di semua VM TPU di lingkungan Multislice Anda:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Meluncurkan pelatihan dengan konfigurasi yang sesuai

    Contoh ini menunjukkan cara mengonfigurasi model LmCloudSpmd2B untuk Autocheckpoint saat melakukan pelatihan di lingkungan Multislice. Sebelum menjalankan skrip pelatihan, tetapkan DCN_MESH_SHAPE ke [2, 1, 1] seperti yang ditunjukkan pada kode berikut:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Saat meluncurkan pelatihan, selain tanda command line yang dibahas dalam kasus slice tunggal, diperlukan tiga lagi:

    • num_hosts: jumlah total host. Dalam hal ini, nilainya adalah 2.
    • host_index: indeks host yang meluncurkan pelatihan. Nilai ini bervariasi dari 0 hingga N-1 dengan N adalah jumlah total host.
    • server_addr: alamat IP pekerja 0 node 0, dengan port yang tidak digunakan (misalnya, 8476). Untuk menemukan informasi ini, gunakan hostname -i pada pekerja 0 dari node 0.

Pos pemeriksaan otomatis dengan Orbax

Fitur Autocheckpoint tidak terbatas pada MaxText atau Pax. Setiap framework yang dapat menangkap sinyal SIGTERM dan memulai proses checkpoint akan berfungsi dengan infrastruktur yang disediakan oleh Autocheckpoint. Orbax, namespace yang menyediakan library utilitas umum untuk pengguna JAX, menyediakan kemampuan ini.

Seperti yang dijelaskan dalam dokumentasi Oracle, kemampuan ini diaktifkan secara default untuk pengguna orbax.checkpoint.CheckpointManager. Metode save yang dipanggil setelah setiap langkah akan otomatis memeriksa apakah peristiwa pemeliharaan akan segera terjadi. Jika demikian, metode ini akan menyimpan checkpoint meskipun nomor langkahnya bukan kelipatan save_interval_steps. Dokumentasi GitHub juga menggambarkan cara keluar dari pelatihan setelah menyimpan Autocheckpoint, dengan modifikasi pada kode pengguna.