Menskalakan workload ML menggunakan Ray
Dokumen ini memberikan detail tentang cara menjalankan beban kerja machine learning (ML) dengan Ray dan JAX di TPU.
Petunjuk ini mengasumsikan bahwa Anda telah menyiapkan lingkungan Ray dan TPU, termasuk lingkungan software yang menyertakan JAX dan paket terkait lainnya. Untuk membuat cluster TPU Ray, ikuti petunjuk di Memulai Google Cloud Cluster GKE dengan TPU untuk KubeRay. Untuk informasi selengkapnya tentang cara menggunakan TPU dengan KubeRay, lihat Menggunakan TPU dengan KubeRay.
Menjalankan workload JAX di TPU host tunggal
Contoh skrip berikut menunjukkan cara menjalankan fungsi JAX di cluster Ray dengan TPU satu host, seperti v6e-4. Jika Anda memiliki TPU multi-host, skrip ini akan berhenti merespons karena model eksekusi multi-pengontrol JAX. Untuk mengetahui informasi selengkapnya tentang cara menjalankan Ray di TPU multi-host, lihat Menjalankan workload JAX di TPU multi-host.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
Jika Anda terbiasa menjalankan Ray dengan GPU, ada beberapa perbedaan utama saat menggunakan TPU:
- Daripada menetapkan
num_gpus
, Anda menentukanTPU
sebagai resource kustom dan menetapkan jumlah chip TPU. - Anda menentukan TPU menggunakan jumlah chip per node pekerja Ray. Misalnya, jika Anda menggunakan v6e-4, menjalankan fungsi jarak jauh dengan
TPU
yang ditetapkan ke 4 akan menggunakan seluruh host TPU.- Hal ini berbeda dengan cara GPU biasanya berjalan, dengan satu proses per host.
Menetapkan
TPU
ke angka yang bukan 4 tidak direkomendasikan. - Pengecualian: Jika memiliki
v6e-8
atauv5litepod-8
satu host, Anda harus menetapkan nilai ini ke 8.
- Hal ini berbeda dengan cara GPU biasanya berjalan, dengan satu proses per host.
Menetapkan
Menjalankan workload JAX di TPU multi-host
Contoh skrip berikut menunjukkan cara menjalankan fungsi JAX di cluster Ray dengan TPU multi-host. Contoh skrip menggunakan v6e-16.
Jika Anda ingin menjalankan workload di cluster dengan beberapa slice TPU, lihat Mengontrol setiap slice TPU.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
Jika Anda terbiasa menjalankan Ray dengan GPU, ada beberapa perbedaan utama saat menggunakan TPU:
- Mirip dengan workload PyTorch di GPU:
- Beban kerja JAX di TPU berjalan dalam mode multi-pengontrol, satu program, beberapa data (SPMD).
- Kolektif antarperangkat ditangani oleh framework machine learning.
- Tidak seperti workload PyTorch di GPU, JAX memiliki tampilan global tentang perangkat yang tersedia di cluster.
Menjalankan workload JAX Multislice
Multislice memungkinkan Anda menjalankan workload yang mencakup beberapa slice TPU dalam satu Pod TPU atau di beberapa Pod melalui jaringan pusat data.
Untuk memudahkan, Anda dapat menggunakan paket
ray-tpu
eksperimental untuk menyederhanakan
interaksi Ray dengan slice TPU. Instal ray-tpu
menggunakan pip
:
pip install ray-tpu
Contoh skrip berikut menunjukkan cara menggunakan paket ray-tpu
untuk menjalankan
beban kerja Multislice menggunakan aktor atau tugas Ray:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
Resource TPU dan Ray
Ray memperlakukan TPU secara berbeda dari GPU untuk mengakomodasi perbedaan penggunaan. Dalam contoh berikut, ada total sembilan node Ray:
- Node head Ray berjalan di VM
n1-standard-16
. - Node pekerja Ray berjalan di dua TPU
v6e-16
. Setiap TPU terdiri dari empat pekerja.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Deskripsi kolom penggunaan resource:
CPU
: Jumlah total CPU yang tersedia di cluster.TPU
: Jumlah TPU chip dalam cluster.TPU-v6e-16-head
: ID khusus untuk resource yang sesuai dengan pekerja 0 dari slice TPU. Hal ini penting untuk mengakses setiap slice TPU.memory
: Memori heap pekerja yang digunakan oleh aplikasi Anda.object_store_memory
: Memori yang digunakan saat aplikasi Anda membuat objek di penyimpanan objek menggunakanray.put
dan saat menampilkan nilai dari fungsi jarak jauh.tpu-group-0
dantpu-group-1
: ID unik untuk setiap slice TPU. Hal ini penting untuk menjalankan tugas di slice. Kolom ini ditetapkan ke 4 karena ada 4 host per slice TPU di v6e-16.
Mengontrol setiap slice TPU
Praktik umum dengan Ray dan TPU adalah menjalankan beberapa beban kerja dalam slice TPU yang sama, misalnya, dalam penyesuaian atau penayangan hyperparameter.
Slice TPU memerlukan pertimbangan khusus saat menggunakan Ray untuk penyediaan dan penjadwalan tugas.
Menjalankan workload satu slice
Saat proses Ray dimulai di slice TPU (menjalankan ray start
), proses tersebut akan otomatis mendeteksi informasi tentang slice. Misalnya, topologi, jumlah pekerja dalam slice, dan apakah proses berjalan di pekerja 0.
Saat Anda menjalankan ray status
di TPU v6e-16 dengan nama "my-tpu", outputnya
akan terlihat seperti berikut:
worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}
"TPU-v6e-16-head"
adalah label resource untuk pekerja 0 dari slice.
"TPU": 4
menunjukkan bahwa setiap pekerja
memiliki 4 chip. "my-tpu"
adalah nama TPU. Anda dapat menggunakan nilai ini untuk menjalankan beban kerja di TPU dalam slice yang sama, seperti dalam contoh berikut.
Asumsikan Anda ingin menjalankan fungsi berikut pada semua pekerja dalam slice:
@ray.remote()
def my_function():
return jax.device_count()
Anda perlu menargetkan pekerja 0 dari slice, lalu memberi tahu pekerja 0 cara menyiarkan
my_function
ke setiap pekerja dalam slice:
@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> returns my-tpu
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4
Contoh ini melakukan langkah-langkah berikut:
@ray.remote(resources={"TPU-v6e-16-head": 1})
: Fungsirun_on_pod
berjalan pada pekerja yang memiliki label resourceTPU-v6e-16-head
, yang menargetkan pekerja arbitrer 0.tpu_name = ray.util.accelerators.tpu.get_current_pod_name()
: Mendapatkan nama TPU.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
: Mendapatkan jumlah pekerja dalam slice.remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4})
: Tambahkan label resource yang berisi nama TPU dan persyaratan resource"TPU": 4
ke fungsimy_function
.- Karena setiap pekerja di slice TPU memiliki label resource kustom untuk slice tempatnya berada, Ray hanya akan menjadwalkan beban kerja pada pekerja dalam slice TPU yang sama.
- Tindakan ini juga akan mencadangkan 4 pekerja TPU untuk fungsi jarak jauh, sehingga Ray tidak akan menjadwalkan beban kerja TPU lainnya di Pod Ray tersebut.
- Karena
run_on_pod
hanya menggunakan resource logikaTPU-v6e-16-head
,my_function
juga akan berjalan di pekerja 0, tetapi dalam proses yang berbeda.
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
: Panggil fungsimy_function
yang diubah beberapa kali sama dengan jumlah pekerja dan tampilkan hasilnya.h = run_on_pod(my_function).remote()
:run_on_pod
akan dieksekusi secara asinkron dan tidak memblokir proses utama.
Penskalaan otomatis slice TPU
Ray di TPU mendukung penskalaan otomatis pada tingkat perincian slice TPU. Anda dapat mengaktifkan fitur ini menggunakan fitur penyediaan otomatis node GKE (NAP). Anda dapat
menjalankan fitur ini menggunakan Ray Autoscaler dan KubeRay. Jenis resource kepala
digunakan untuk memberi sinyal penskalaan otomatis ke Ray, misalnya, TPU-v6e-32-head
.