Scala i workload ML utilizzando Ray

Questo documento fornisce dettagli su come eseguire carichi di lavoro di machine learning (ML) con Ray e JAX su TPU.

Queste istruzioni presuppongono che tu abbia già configurato un ambiente Ray e TPU, incluso un ambiente software che includa JAX e altri pacchetti correlati. Per creare un cluster Ray TPU, segui le istruzioni riportate in Avvia Google Cloud un cluster GKE con TPU per KubeRay. Per ulteriori informazioni sull'utilizzo delle TPU con KubeRay, consulta Utilizzare le TPU con KubeRay.

Esegui un carico di lavoro JAX su una TPU con un solo host

Lo script di esempio riportato di seguito mostra come eseguire una funzione JAX su un cluster Ray con una TPU a un solo host, ad esempio una v6e-4. Se hai una TPU multi-host, questo script non risponde a causa del modello di esecuzione con più controller di JAX. Per ulteriori informazioni sull'esecuzione di Ray su una TPU multi-host, consulta Eseguire un carico di lavoro JAX su una TPU multi-host.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:

  • Anziché impostare num_gpus, specifica TPU come risorsa personalizzata e imposta il numero di chip TPU.
  • Specifica la TPU utilizzando il numero di chip per nodo worker Ray. Ad esempio, se utilizzi una v6e-4, l'esecuzione di una funzione remota con TPU impostato su 4 consuma l'intero host TPU.
    • Questo è diverso dal modo in cui in genere vengono eseguite le GPU, con un processo per host. L'impostazione di TPU su un numero diverso da 4 non è consigliata.
    • Eccezione: se hai un v6e-8 o un v5litepod-8 a host singolo, devi impostare questo valore su 8.

Esegui un carico di lavoro JAX su una TPU multi-host

Lo script di esempio seguente mostra come eseguire una funzione JAX su un cluster Ray con una TPU multi-host. Lo script di esempio utilizza una versione v6e-16.

Se vuoi eseguire il tuo carico di lavoro su un cluster con più sezioni TPU, consulta Controllare le singole sezioni TPU.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Se sei abituato a eseguire Ray con GPU, esistono alcune differenze chiave quando utilizzi le TPU:

  • Simile ai carichi di lavoro di PyTorch sulle GPU:
  • A differenza dei carichi di lavoro PyTorch sulle GPU, JAX ha una vista globale dei dispositivi disponibili nel cluster.

Esegui un carico di lavoro JAX multislice

Multislice ti consente di eseguire carichi di lavoro che coprono più slice TPU all'interno di un singolo pod di TPU o in più pod sulla rete del data center.

Per comodità, puoi utilizzare il pacchetto sperimentale ray-tpu per semplificare le interazioni di Ray con i sezioni TPU. Installa ray-tpu utilizzando pip:

pip install ray-tpu

Lo script di esempio seguente mostra come utilizzare il pacchetto ray-tpu per eseguire carichi di lavoro multislice utilizzando gli attori o le attività Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Risorse TPU e Ray

Ray tratta le TPU in modo diverso dalle GPU per tenere conto della differenza di utilizzo. Nell'esempio seguente sono presenti in totale nove nodi Ray:

  • Il nodo principale di Ray è in esecuzione su una VM n1-standard-16.
  • I nodi worker Ray sono in esecuzione su due TPU v6e-16. Ogni TPU è costituita da quattro worker.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Descrizioni dei campi di utilizzo delle risorse:

  • CPU: il numero totale di CPU disponibili nel cluster.
  • TPU: il numero di chip TPU nel cluster.
  • TPU-v6e-16-head: un identificatore speciale per la risorsa corrispondente al worker 0 di una sezione TPU. Questo è importante per accedere alle singole sezioni di TPU.
  • memory: memoria heap del worker utilizzata dalla tua applicazione.
  • object_store_memory: memoria utilizzata quando l'applicazione crea oggetti nell'object store utilizzando ray.put e quando restituisce valori dalle funzioni remote.
  • tpu-group-0 e tpu-group-1: identificatori univoci per le singole sezioni TPU. Questo è importante per l'esecuzione di job su slice. Questi campi sono impostati su 4 perché in una v6e-16 sono presenti 4 host per ogni slice TPU.

Controllare le singole sezioni TPU

Una pratica comune con Ray e le TPU è eseguire più carichi di lavoro nello stesso slice TPU, ad esempio nell'ottimizzazione degli iperparametri o nel serving.

Gli slice TPU richiedono un'attenzione particolare quando si utilizza Ray sia per il provisioning sia per la pianificazione dei job.

Esegui carichi di lavoro a slice singolo

Quando il processo Ray viene avviato sugli slice TPU (in esecuzione su ray start), il processo rileva automaticamente le informazioni sullo slice. Ad esempio, la topologia, il numero di worker nel segmento e se il processo è in esecuzione sul worker 0.

Quando esegui ray status su una TPU v6e-16 con il nome "my-tpu", l'output è simile al seguente:

worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}

"TPU-v6e-16-head" è l'etichetta della risorsa per il worker 0 del segmento. "TPU": 4 indica che ogni worker ha 4 chip. "my-tpu" è il nome della TPU. Puoi utilizzare questi valori per eseguire un carico di lavoro sulle TPU all'interno dello stesso slice, come nell'esempio seguente.

Supponiamo che tu voglia eseguire la seguente funzione su tutti i worker di uno slice:

@ray.remote()
def my_function():
    return jax.device_count()

Devi scegliere come target il worker 0 del segmento, quindi dire al worker 0 come trasmetteremy_function a tutti i worker del segmento:

@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> returns my-tpu
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4

L'esempio esegue i seguenti passaggi:

  • @ray.remote(resources={"TPU-v6e-16-head": 1}): la funzione run_on_pod viene eseguita su un worker con l'etichetta della risorsa TPU-v6e-16-head, che ha come target un worker 0 arbitrario.
  • tpu_name = ray.util.accelerators.tpu.get_current_pod_name(): recupera il nome della TPU.
  • num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count(): restituisce il numero di worker nel segmento.
  • remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}): aggiungi l'etichetta della risorsa contenente il nome della TPU e il requisito della risorsa "TPU": 4 alla funzione my_function.
    • Poiché ogni worker nel segmento TPU ha un'etichetta della risorsa personalizzata per il segmento in cui si trova, Ray pianifica il carico di lavoro solo sui worker all'interno dello stesso segmento TPU.
    • Vengono inoltre riservati 4 worker TPU per la funzione remota, quindi Ray non schedulerà altri carichi di lavoro TPU su quel pod Ray.
    • Poiché run_on_pod utilizza solo la risorsa logica TPU-v6e-16-head, my_function verrà eseguito anche sul worker 0, ma in un processo diverso.
  • return ray.get([remote_fn.remote() for _ in range(num_hosts)]): invoca la funzione my_function modificata un numero di volte pari al numero di thread e restituisce i risultati.
  • h = run_on_pod(my_function).remote(): run_on_pod verrà eseguito in modo asincrono e non bloccherà il processo principale.

Scalabilità automatica delle sezioni TPU

Ray su TPU supporta la scalabilità automatica in base alla granularità di una sezione TPU. Puoi attivare questa funzionalità utilizzando la funzionalità Provisioning automatico dei nodi GKE (NAP). Puoi eseguire questa funzionalità utilizzando Ray Autoscaler e KubeRay. Il tipo di risorsa head viene utilizzato per segnalare la scalabilità automatica a Ray, ad esempio TPU-v6e-32-head.