Menskalakan workload ML menggunakan Ray

Pengantar

Alat Cloud TPU Ray menggabungkan Cloud TPU API dan Ray Jobs dengan tujuan meningkatkan pengalaman pengembangan pengguna di Cloud TPU. Panduan pengguna ini memberikan contoh minimal tentang cara menggunakan Ray dengan Cloud TPU. Contoh ini tidak dimaksudkan untuk digunakan dalam layanan produksi dan hanya untuk tujuan ilustrasi.

Apa saja yang disertakan dalam alat ini?

Untuk memudahkan Anda, alat ini menyediakan:

  • Abstraksi umum yang menyembunyikan boilerplate untuk tindakan TPU umum
  • Contoh mainan yang dapat Anda buat fork untuk alur kerja dasar Anda sendiri

Secara khusus:

  • tpu_api.py: Wrapper Python untuk operasi TPU dasar menggunakan Cloud TPU API.
  • tpu_controller.py: Representasi kelas TPU. Pada dasarnya, ini adalah wrapper untuk tpu_api.py.
  • ray_tpu_controller.py: Pengontrol TPU dengan fungsi Ray. Ini memisahkan boilerplate untuk cluster Ray dan Ray Jobs.
  • run_basic_jax.py: Contoh dasar yang menunjukkan cara menggunakan RayTpuController untuk print(jax.device_count()).
  • run_hp_search.py: Contoh dasar yang menunjukkan bagaimana Ray Tune dapat digunakan dengan JAX/Flax di MNIST.
  • run_pax_autoresume.py: Contoh yang menunjukkan cara menggunakan RayTpuController untuk pelatihan fault-tolerant menggunakan PAX sebagai contoh beban kerja.

Menyiapkan node head cluster Ray

Salah satu cara dasar untuk menggunakan Ray dengan Pod TPU adalah dengan menyiapkan Pod TPU sebagai cluster Ray. Membuat VM CPU terpisah sebagai VM koordinator adalah cara alami untuk melakukannya. Gambar berikut menunjukkan contoh konfigurasi cluster Ray:

Contoh konfigurasi gugus Ray

Perintah berikut menunjukkan cara menyiapkan cluster Ray menggunakan Google Cloud CLI:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Untuk memudahkan Anda, kami juga menyediakan skrip dasar untuk membuat VM koordinator dan men-deploy konten folder ini ke VM koordinator Anda. Untuk kode sumber, lihat create_cpu.sh dan deploy.sh.

Skrip ini menetapkan beberapa nilai default:

  • create_cpu.sh akan membuat VM bernama $USER-admin dan akan menggunakan project dan zona apa pun yang ditetapkan ke default gcloud config Anda. Jalankan gcloud config list untuk melihat default tersebut.
  • create_cpu.sh secara default mengalokasikan ukuran boot disk sebesar 200 GB.
  • deploy.sh mengasumsikan bahwa nama VM Anda adalah $USER-admin. Jika Anda mengubah nilai tersebut di create_cpu.sh, pastikan untuk mengubahnya di deploy.sh.

Untuk menggunakan skrip praktis:

  1. Clone repositori GitHub ke komputer lokal Anda dan masukkan folder ray_tpu:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. Jika Anda tidak memiliki akun layanan khusus untuk administrasi TPU (sangat direkomendasikan), siapkan:

    $ ./create_tpu_service_account.sh
    
  3. Buat VM koordinator:

    $ ./create_cpu.sh
    

    Skrip ini menginstal dependensi pada VM menggunakan skrip startup dan otomatis melakukan pemblokiran hingga skrip startup selesai.

  4. Deploy kode lokal ke VM koordinator:

    $ ./deploy.sh
    
  5. SSH ke VM:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    Penerusan port diaktifkan di sini karena Ray akan otomatis memulai dasbor di port 8265. Dari komputer tempat Anda menjalankan SSH ke VM koordinator, Anda dapat mengakses dasbor ini di http://127.0.0.1:8265/.

  6. Jika Anda melewati langkah 0, siapkan kredensial gcloud di dalam CPU VM:

    $ (vm) gcloud auth login --update-adc
    

    Langkah ini menetapkan info project ID dan mengizinkan Cloud TPU API untuk berjalan di VM koordinator.

  7. Persyaratan penginstalan:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. Mulai Ray di VM koordinator, dan VM koordinator akan menjadi node kepala cluster Ray:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

Contoh penggunaan

Contoh JAX dasar

run_basic_jax.py adalah contoh minimal yang menunjukkan cara menggunakan lingkungan runtime Ray Jobs dan Ray di cluster Ray dengan VM TPU untuk menjalankan beban kerja JAX.

Untuk framework ML yang kompatibel dengan Cloud TPU yang menggunakan model pemrograman multi-pengontrol, seperti JAX dan PyTorch/XLA PJRT, Anda harus menjalankan setidaknya satu proses per host. Untuk mengetahui informasi selengkapnya, lihat Model pemrograman multiproses. Dalam praktiknya, hal ini mungkin terlihat seperti ini:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Jika memiliki lebih dari ~16 host, seperti v4-128, Anda akan mengalami masalah skalabilitas SSH dan perintah Anda mungkin harus diubah menjadi:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Hal ini dapat menjadi penghambat kecepatan developer jika my_bug_free_python_code berisi bug. Salah satu cara untuk mengatasi masalah ini adalah dengan menggunakan orchestrator seperti Kubernetes atau Ray. Ray menyertakan konsep lingkungan runtime yang, jika diterapkan, akan men-deploy kode dan dependensi saat aplikasi Ray dijalankan.

Dengan menggabungkan lingkungan runtime Ray dengan cluster Ray dan Ray Jobs, Anda dapat mengabaikan siklus SCP/SSH. Dengan asumsi Anda mengikuti contoh di atas, Anda dapat menjalankan ini dengan:

$ python3 legacy/run_basic_jax.py

Outputnya mirip dengan hal berikut ini:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Pelatihan fault-tolerant

Contoh ini menunjukkan cara menggunakan RayTpuController untuk mengimplementasikan pelatihan fault-tolerant. Untuk contoh ini, kami melatih LLM sederhana di PAX pada v4-16. Namun, perhatikan bahwa Anda dapat mengganti beban kerja PAX ini dengan beban kerja yang berjalan lama lainnya. Untuk kode sumber, lihat run_pax_autoresume.py.

Untuk menjalankan contoh ini:

  1. Clone paxml ke VM koordinator Anda:

    $ git clone https://github.com/google/paxml.git
    

    Untuk menunjukkan kemudahan penggunaan yang disediakan oleh Ray Runtime Environment untuk membuat dan men-deploy perubahan JAX, contoh ini mengharuskan Anda memodifikasi PAX.

  2. Tambahkan konfigurasi eksperimen baru:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Jalankan run_pax_autoresume.py:

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. Saat beban kerja berjalan, bereksperimenlah dengan apa yang terjadi jika Anda menghapus TPU, secara default, bernama $USER-tpu-ray:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray akan mendeteksi bahwa TPU tidak berfungsi dengan pesan berikut:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    Tugas tersebut juga akan otomatis membuat ulang TPU VM dan memulai ulang tugas pelatihan agar dapat melanjutkan pelatihan dari checkpoint terbaru (langkah 200 dalam contoh ini):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

Contoh ini menampilkan penggunaan Ray Tune dari Ray AIR ke penyesuaian hyperparameter MNIST dari JAX/FLAX. Untuk kode sumber, lihat run_hp_search.py.

Untuk menjalankan contoh ini:

  1. Instal persyaratan:

    $ pip3 install -r src/tune/requirements.txt
    
  2. Jalankan run_hp_search.py:

    $ python3 src/tune/run_hp_search.py
    

    Outputnya mirip dengan hal berikut ini:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

Pemecahan masalah

Node head ray tidak dapat terhubung

Jika Anda menjalankan beban kerja yang membuat/menghapus siklus proses TPU, terkadang hal ini tidak memutuskan koneksi host TPU dari cluster Ray. Error ini mungkin muncul sebagai error gRPC yang menandakan bahwa node Ray head tidak dapat terhubung ke kumpulan alamat IP.

Oleh karena itu, Anda mungkin perlu menghentikan sesi ray (ray stop) dan memulai ulang sesi tersebut (ray start --head --port=6379 --num-cpus=0).

Ray Job gagal secara langsung tanpa output log

PAX bersifat eksperimental dan contoh ini dapat rusak karena dependensi pip. Jika ini terjadi, Anda mungkin melihat sesuatu seperti ini:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Untuk mengetahui penyebab utama error ini, Anda dapat membuka http://127.0.0.1:8265/ dan melihat dasbor untuk tugas yang berjalan/gagal, yang akan memberikan informasi selengkapnya. runtime_env_agent.log menampilkan semua informasi error yang terkait dengan penyiapan runtime_env, misalnya:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts