Menskalakan workload ML menggunakan Ray

Dokumen ini memberikan detail tentang cara menjalankan beban kerja machine learning (ML) dengan Ray dan JAX di TPU.

Petunjuk ini mengasumsikan bahwa Anda telah menyiapkan lingkungan Ray dan TPU, termasuk lingkungan software yang menyertakan JAX dan paket terkait lainnya. Untuk membuat cluster TPU Ray, ikuti petunjuk di Memulai Google Cloud Cluster GKE dengan TPU untuk KubeRay. Untuk informasi selengkapnya tentang cara menggunakan TPU dengan KubeRay, lihat Menggunakan TPU dengan KubeRay.

Menjalankan workload JAX di TPU host tunggal

Contoh skrip berikut menunjukkan cara menjalankan fungsi JAX di cluster Ray dengan TPU satu host, seperti v6e-4. Jika Anda memiliki TPU multi-host, skrip ini akan berhenti merespons karena model eksekusi multi-pengontrol JAX. Untuk mengetahui informasi selengkapnya tentang cara menjalankan Ray di TPU multi-host, lihat Menjalankan workload JAX di TPU multi-host.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Jika Anda terbiasa menjalankan Ray dengan GPU, ada beberapa perbedaan utama saat menggunakan TPU:

  • Daripada menetapkan num_gpus, Anda menentukan TPU sebagai resource kustom dan menetapkan jumlah chip TPU.
  • Anda menentukan TPU menggunakan jumlah chip per node pekerja Ray. Misalnya, jika Anda menggunakan v6e-4, menjalankan fungsi jarak jauh dengan TPU yang ditetapkan ke 4 akan menggunakan seluruh host TPU.
    • Hal ini berbeda dengan cara GPU biasanya berjalan, dengan satu proses per host. Menetapkan TPU ke angka yang bukan 4 tidak direkomendasikan.
    • Pengecualian: Jika memiliki v6e-8 atau v5litepod-8 satu host, Anda harus menetapkan nilai ini ke 8.

Menjalankan workload JAX di TPU multi-host

Contoh skrip berikut menunjukkan cara menjalankan fungsi JAX di cluster Ray dengan TPU multi-host. Contoh skrip menggunakan v6e-16.

Jika Anda ingin menjalankan workload di cluster dengan beberapa slice TPU, lihat Mengontrol setiap slice TPU.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Jika Anda terbiasa menjalankan Ray dengan GPU, ada beberapa perbedaan utama saat menggunakan TPU:

  • Mirip dengan workload PyTorch di GPU:
  • Tidak seperti workload PyTorch di GPU, JAX memiliki tampilan global tentang perangkat yang tersedia di cluster.

Menjalankan workload JAX Multislice

Multislice memungkinkan Anda menjalankan workload yang mencakup beberapa slice TPU dalam satu Pod TPU atau di beberapa Pod melalui jaringan pusat data.

Untuk memudahkan, Anda dapat menggunakan paket ray-tpu eksperimental untuk menyederhanakan interaksi Ray dengan slice TPU. Instal ray-tpu menggunakan pip:

pip install ray-tpu

Contoh skrip berikut menunjukkan cara menggunakan paket ray-tpu untuk menjalankan beban kerja Multislice menggunakan aktor atau tugas Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Resource TPU dan Ray

Ray memperlakukan TPU secara berbeda dari GPU untuk mengakomodasi perbedaan penggunaan. Dalam contoh berikut, ada total sembilan node Ray:

  • Node head Ray berjalan di VM n1-standard-16.
  • Node pekerja Ray berjalan di dua TPU v6e-16. Setiap TPU terdiri dari empat pekerja.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Deskripsi kolom penggunaan resource:

  • CPU: Jumlah total CPU yang tersedia di cluster.
  • TPU: Jumlah TPU chip dalam cluster.
  • TPU-v6e-16-head: ID khusus untuk resource yang sesuai dengan pekerja 0 dari slice TPU. Hal ini penting untuk mengakses setiap slice TPU.
  • memory: Memori heap pekerja yang digunakan oleh aplikasi Anda.
  • object_store_memory: Memori yang digunakan saat aplikasi Anda membuat objek di penyimpanan objek menggunakan ray.put dan saat menampilkan nilai dari fungsi jarak jauh.
  • tpu-group-0 dan tpu-group-1: ID unik untuk setiap slice TPU. Hal ini penting untuk menjalankan tugas di slice. Kolom ini ditetapkan ke 4 karena ada 4 host per slice TPU di v6e-16.

Mengontrol setiap slice TPU

Praktik umum dengan Ray dan TPU adalah menjalankan beberapa beban kerja dalam slice TPU yang sama, misalnya, dalam penyesuaian atau penayangan hyperparameter.

Slice TPU memerlukan pertimbangan khusus saat menggunakan Ray untuk penyediaan dan penjadwalan tugas.

Menjalankan workload satu slice

Saat proses Ray dimulai di slice TPU (menjalankan ray start), proses tersebut akan otomatis mendeteksi informasi tentang slice. Misalnya, topologi, jumlah pekerja dalam slice, dan apakah proses berjalan di pekerja 0.

Saat Anda menjalankan ray status di TPU v6e-16 dengan nama "my-tpu", outputnya akan terlihat seperti berikut:

worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}

"TPU-v6e-16-head" adalah label resource untuk pekerja 0 dari slice. "TPU": 4 menunjukkan bahwa setiap pekerja memiliki 4 chip. "my-tpu" adalah nama TPU. Anda dapat menggunakan nilai ini untuk menjalankan beban kerja di TPU dalam slice yang sama, seperti dalam contoh berikut.

Asumsikan Anda ingin menjalankan fungsi berikut pada semua pekerja dalam slice:

@ray.remote()
def my_function():
    return jax.device_count()

Anda perlu menargetkan pekerja 0 dari slice, lalu memberi tahu pekerja 0 cara menyiarkan my_function ke setiap pekerja dalam slice:

@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> returns my-tpu
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4

Contoh ini melakukan langkah-langkah berikut:

  • @ray.remote(resources={"TPU-v6e-16-head": 1}): Fungsi run_on_pod berjalan pada pekerja yang memiliki label resource TPU-v6e-16-head, yang menargetkan pekerja arbitrer 0.
  • tpu_name = ray.util.accelerators.tpu.get_current_pod_name(): Mendapatkan nama TPU.
  • num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count(): Mendapatkan jumlah pekerja dalam slice.
  • remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}): Tambahkan label resource yang berisi nama TPU dan persyaratan resource "TPU": 4 ke fungsi my_function.
    • Karena setiap pekerja di slice TPU memiliki label resource kustom untuk slice tempatnya berada, Ray hanya akan menjadwalkan beban kerja pada pekerja dalam slice TPU yang sama.
    • Tindakan ini juga akan mencadangkan 4 pekerja TPU untuk fungsi jarak jauh, sehingga Ray tidak akan menjadwalkan beban kerja TPU lainnya di Pod Ray tersebut.
    • Karena run_on_pod hanya menggunakan resource logika TPU-v6e-16-head, my_function juga akan berjalan di pekerja 0, tetapi dalam proses yang berbeda.
  • return ray.get([remote_fn.remote() for _ in range(num_hosts)]): Panggil fungsi my_function yang diubah beberapa kali sama dengan jumlah pekerja dan tampilkan hasilnya.
  • h = run_on_pod(my_function).remote(): run_on_pod akan dieksekusi secara asinkron dan tidak memblokir proses utama.

Penskalaan otomatis slice TPU

Ray di TPU mendukung penskalaan otomatis pada tingkat perincian slice TPU. Anda dapat mengaktifkan fitur ini menggunakan fitur penyediaan otomatis node GKE (NAP). Anda dapat menjalankan fitur ini menggunakan Ray Autoscaler dan KubeRay. Jenis resource kepala digunakan untuk memberi sinyal penskalaan otomatis ke Ray, misalnya, TPU-v6e-32-head.