Scala i workload ML utilizzando Ray
Questo documento fornisce dettagli su come eseguire carichi di lavoro di machine learning (ML) con Ray e JAX su TPU.
Queste istruzioni presuppongono che tu abbia già configurato un ambiente Ray e TPU, incluso un ambiente software che includa JAX e altri pacchetti correlati. Per creare un cluster Ray TPU, segui le istruzioni riportate in Avvia Google Cloud un cluster GKE con TPU per KubeRay. Per ulteriori informazioni sull'utilizzo delle TPU con KubeRay, consulta Utilizzare le TPU con KubeRay.
Esegui un carico di lavoro JAX su una TPU con un solo host
Lo script di esempio riportato di seguito mostra come eseguire una funzione JAX su un cluster Ray con una TPU a un solo host, ad esempio una v6e-4. Se hai una TPU multi-host, questo script non risponde a causa del modello di esecuzione con più controller di JAX. Per ulteriori informazioni sull'esecuzione di Ray su una TPU multi-host, consulta Eseguire un carico di lavoro JAX su una TPU multi-host.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:
- Anziché impostare
num_gpus
, specificaTPU
come risorsa personalizzata e imposta il numero di chip TPU. - Specifica la TPU utilizzando il numero di chip per nodo worker Ray. Ad esempio, se utilizzi una v6e-4, l'esecuzione di una funzione remota con
TPU
impostato su 4 consuma l'intero host TPU.- Questo è diverso dal modo in cui in genere vengono eseguite le GPU, con un processo per host.
L'impostazione di
TPU
su un numero diverso da 4 non è consigliata. - Eccezione: se hai un
v6e-8
o unv5litepod-8
a host singolo, devi impostare questo valore su 8.
- Questo è diverso dal modo in cui in genere vengono eseguite le GPU, con un processo per host.
L'impostazione di
Esegui un carico di lavoro JAX su una TPU multi-host
Lo script di esempio seguente mostra come eseguire una funzione JAX su un cluster Ray con una TPU multi-host. Lo script di esempio utilizza una versione v6e-16.
Se vuoi eseguire il tuo carico di lavoro su un cluster con più sezioni TPU, consulta Controllare le singole sezioni TPU.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:
- Simile ai carichi di lavoro di PyTorch sulle GPU:
- I carichi di lavoro JAX sulle TPU vengono eseguiti in un modello con più controller, un singolo programma e più dati (SPMD).
- I gruppi tra dispositivi sono gestiti dal framework di machine learning.
- A differenza dei carichi di lavoro PyTorch sulle GPU, JAX ha una vista globale dei dispositivi disponibili nel cluster.
Esegui un carico di lavoro JAX multislice
Multislice ti consente di eseguire carichi di lavoro che coprono più slice TPU all'interno di un singolo pod di TPU o in più pod sulla rete del data center.
Per comodità, puoi utilizzare il pacchetto sperimentale
ray-tpu
per semplificare
le interazioni di Ray con i sezioni TPU. Installa ray-tpu
utilizzando pip
:
pip install ray-tpu
Lo script di esempio seguente mostra come utilizzare il pacchetto ray-tpu
per eseguire carichi di lavoro multislice utilizzando gli attori o le attività Ray:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
Risorse TPU e Ray
Ray tratta le TPU in modo diverso dalle GPU per tenere conto della differenza di utilizzo. Nell'esempio seguente sono presenti in totale nove nodi Ray:
- Il nodo principale di Ray è in esecuzione su una VM
n1-standard-16
. - I nodi worker Ray sono in esecuzione su due TPU
v6e-16
. Ogni TPU è costituita da quattro worker.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Descrizioni dei campi di utilizzo delle risorse:
CPU
: il numero totale di CPU disponibili nel cluster.TPU
: il numero di chip TPU nel cluster.TPU-v6e-16-head
: un identificatore speciale per la risorsa corrispondente al worker 0 di una sezione TPU. Questo è importante per accedere alle singole sezioni di TPU.memory
: memoria heap del worker utilizzata dalla tua applicazione.object_store_memory
: memoria utilizzata quando l'applicazione crea oggetti nell'object store utilizzandoray.put
e quando restituisce valori dalle funzioni remote.tpu-group-0
etpu-group-1
: identificatori univoci per le singole sezioni TPU. Questo è importante per l'esecuzione di job su slice. Questi campi sono impostati su 4 perché in una v6e-16 sono presenti 4 host per ogni slice TPU.
Controllare le singole sezioni TPU
Una pratica comune con Ray e le TPU è eseguire più carichi di lavoro nello stesso slice TPU, ad esempio nell'ottimizzazione degli iperparametri o nel serving.
Gli slice TPU richiedono un'attenzione particolare quando si utilizza Ray sia per il provisioning sia per la pianificazione dei job.
Esegui carichi di lavoro a slice singolo
Quando il processo Ray viene avviato sugli slice TPU (in esecuzione su ray start
), il processo rileva automaticamente le informazioni sullo slice. Ad esempio, la topologia, il numero di worker nel segmento e se il processo è in esecuzione sul worker 0.
Quando esegui ray status
su una TPU v6e-16 con il nome "my-tpu", l'output è simile al seguente:
worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}
"TPU-v6e-16-head"
è l'etichetta della risorsa per il worker 0 del segmento.
"TPU": 4
indica che ogni worker ha 4 chip. "my-tpu"
è il nome della TPU. Puoi utilizzare questi valori per eseguire un carico di lavoro sulle TPU all'interno dello stesso slice, come nell'esempio seguente.
Supponiamo che tu voglia eseguire la seguente funzione su tutti i worker di uno slice:
@ray.remote()
def my_function():
return jax.device_count()
Devi scegliere come target il worker 0 del segmento, quindi dire al worker 0 come trasmetteremy_function
a tutti i worker del segmento:
@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> returns my-tpu
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4
L'esempio esegue i seguenti passaggi:
@ray.remote(resources={"TPU-v6e-16-head": 1})
: la funzionerun_on_pod
viene eseguita su un worker con l'etichetta della risorsaTPU-v6e-16-head
, che ha come target un worker 0 arbitrario.tpu_name = ray.util.accelerators.tpu.get_current_pod_name()
: recupera il nome della TPU.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
: restituisce il numero di worker nel segmento.remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4})
: aggiungi l'etichetta della risorsa contenente il nome della TPU e il requisito della risorsa"TPU": 4
alla funzionemy_function
.- Poiché ogni worker nel segmento TPU ha un'etichetta della risorsa personalizzata per il segmento in cui si trova, Ray pianifica il carico di lavoro solo sui worker all'interno dello stesso segmento TPU.
- Vengono inoltre riservati 4 worker TPU per la funzione remota, quindi Ray non schedulerà altri carichi di lavoro TPU su quel pod Ray.
- Poiché
run_on_pod
utilizza solo la risorsa logicaTPU-v6e-16-head
,my_function
verrà eseguito anche sul worker 0, ma in un processo diverso.
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
: invoca la funzionemy_function
modificata un numero di volte pari al numero di thread e restituisce i risultati.h = run_on_pod(my_function).remote()
:run_on_pod
verrà eseguito in modo asincrono e non bloccherà il processo principale.
Scalabilità automatica delle sezioni TPU
Ray su TPU supporta la scalabilità automatica in base alla granularità di una sezione TPU. Puoi attivare questa funzionalità utilizzando la funzionalità Provisioning automatico dei nodi GKE (NAP). Puoi eseguire questa funzionalità utilizzando Ray Autoscaler e KubeRay. Il tipo di risorsa head viene utilizzato per segnalare la scalabilità automatica a Ray, ad esempio TPU-v6e-32-head
.