Évoluer les charges de travail de ML à l'aide de Ray

Ce document explique comment exécuter des charges de travail de machine learning (ML) avec Ray et JAX sur des TPU.

Ces instructions supposent que vous avez déjà configuré un environnement Ray et TPU, y compris un environnement logiciel qui inclut JAX et d'autres packages associés. Pour créer un cluster TPU Ray, suivez les instructions de la section Démarrer un clusterGoogle Cloud GKE avec des TPU pour KubeRay. Pour en savoir plus sur l'utilisation des TPU avec KubeRay, consultez Utiliser des TPU avec KubeRay.

Exécuter une charge de travail JAX sur un TPU à hôte unique

L'exemple de script suivant montre comment exécuter une fonction JAX sur un cluster Ray avec un TPU à hôte unique, tel qu'un v6e-4. Si vous disposez d'un TPU multi-hôte, ce script cesse de répondre en raison du modèle d'exécution multicontrôleur de JAX. Pour en savoir plus sur l'exécution de Ray sur un TPU multi-hôte, consultez Exécuter une charge de travail JAX sur un TPU multi-hôte.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Si vous avez l'habitude d'exécuter Ray avec des GPU, vous constaterez quelques différences clés lorsque vous utiliserez des TPU:

  • Au lieu de définir num_gpus, vous spécifiez TPU en tant que ressource personnalisée et définissez le nombre de puces TPU.
  • Vous spécifiez le TPU à l'aide du nombre de puces par nœud de travail Ray. Par exemple, si vous utilisez un v6e-4, l'exécution d'une fonction distante avec TPU défini sur 4 consomme l'ensemble de l'hôte TPU.
    • Cela diffère de la manière dont les GPU s'exécutent généralement, avec un processus par hôte. Il est déconseillé de définir TPU sur un nombre autre que 4.
    • Exception: Si vous disposez d'un v6e-8 ou d'un v5litepod-8 à hôte unique, vous devez définir cette valeur sur 8.

Exécuter une charge de travail JAX sur un TPU multi-hôte

L'exemple de script suivant montre comment exécuter une fonction JAX sur un cluster Ray avec un TPU multi-hôte. L'exemple de script utilise une version v6e-16.

Si vous souhaitez exécuter votre charge de travail sur un cluster avec plusieurs tranches de TPU, consultez la section Contrôler des tranches de TPU individuelles.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Si vous avez l'habitude d'exécuter Ray avec des GPU, vous constaterez quelques différences clés lorsque vous utiliserez des TPU:

  • Comme pour les charges de travail PyTorch sur les GPU :
  • Contrairement aux charges de travail PyTorch sur les GPU, JAX dispose d'une vue globale des appareils disponibles dans le cluster.

Exécuter une charge de travail JAX Multislice

Multislice vous permet d'exécuter des charges de travail couvrant plusieurs tranches de TPU dans un seul pod TPU ou dans plusieurs pods sur le réseau du centre de données.

Pour plus de commodité, vous pouvez utiliser le package expérimental ray-tpu pour simplifier les interactions de Ray avec les tranches TPU. Installez ray-tpu à l'aide de pip:

pip install ray-tpu

L'exemple de script suivant montre comment utiliser le package ray-tpu pour exécuter des charges de travail multicouches à l'aide d'acteurs ou de tâches Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Ressources TPU et Ray

Ray traite les TPU différemment des GPU pour tenir compte de la différence d'utilisation. Dans l'exemple suivant, il y a neuf nœuds Ray au total:

  • Le nœud principal Ray s'exécute sur une VM n1-standard-16.
  • Les nœuds de calcul Ray s'exécutent sur deux TPU v6e-16. Chaque TPU constitue quatre nœuds de calcul.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Descriptions des champs d'utilisation des ressources:

  • CPU: nombre total de processeurs disponibles dans le cluster.
  • TPU: nombre de puces TPU dans le cluster.
  • TPU-v6e-16-head: identifiant spécial de la ressource correspondant au nœud de calcul 0 d'une tranche TPU. Cela est important pour accéder à des tranches de TPU individuelles.
  • memory: mémoire de tas de nœuds de calcul utilisée par votre application.
  • object_store_memory: mémoire utilisée lorsque votre application crée des objets dans le magasin d'objets à l'aide de ray.put et lorsqu'elle renvoie des valeurs à partir de fonctions distantes.
  • tpu-group-0 et tpu-group-1: identifiants uniques des tranches de TPU individuelles. Cela est important pour exécuter des tâches sur des tranches. Ces champs sont définis sur 4, car il y a quatre hôtes par tranche de TPU dans un v6e-16.

Contrôler des tranches de TPU individuelles

Une pratique courante avec Ray et les TPU consiste à exécuter plusieurs charges de travail dans la même tranche TPU, par exemple pour le réglage des hyperparamètres ou le traitement.

Les tranches TPU nécessitent une attention particulière lorsque vous utilisez Ray à la fois pour le provisionnement et la planification des tâches.

Exécuter des charges de travail à tranche unique

Lorsque le processus Ray démarre sur des tranches TPU (exécution de ray start), il détecte automatiquement des informations sur la tranche. Par exemple, la topologie, le nombre de nœuds de calcul dans la tranche et si le processus s'exécute sur le nœud de calcul 0.

Lorsque vous exécutez ray status sur un TPU v6e-16 nommé "my-tpu", le résultat ressemble à ce qui suit:

worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}

"TPU-v6e-16-head" est le libellé de ressource du nœud de calcul 0 de la tranche. "TPU": 4 indique que chaque nœud de calcul dispose de quatre puces. "my-tpu" correspond au nom du TPU. Vous pouvez utiliser ces valeurs pour exécuter une charge de travail sur des TPU dans le même segment, comme dans l'exemple suivant.

Supposons que vous souhaitiez exécuter la fonction suivante sur tous les nœuds de calcul d'une tranche:

@ray.remote()
def my_function():
    return jax.device_count()

Vous devez cibler le nœud de calcul 0 de la tranche, puis indiquer au nœud de calcul 0 comment diffuser my_function à tous les nœuds de calcul de la tranche:

@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> returns my-tpu
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4

L'exemple effectue les étapes suivantes:

  • @ray.remote(resources={"TPU-v6e-16-head": 1}): la fonction run_on_pod s'exécute sur un nœud de calcul associé au libellé de ressource TPU-v6e-16-head, qui cible n'importe quel nœud de calcul arbitraire 0.
  • tpu_name = ray.util.accelerators.tpu.get_current_pod_name(): permet d'obtenir le nom du TPU.
  • num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() : obtient le nombre de nœuds de calcul dans la tranche.
  • remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}): ajoutez le libellé de ressource contenant le nom du TPU et l'exigence de ressource "TPU": 4 à la fonction my_function.
    • Étant donné que chaque nœud de calcul de la tranche de TPU possède un libellé de ressource personnalisé pour la tranche dans laquelle il se trouve, Ray ne planifie la charge de travail que sur les nœuds de calcul de la même tranche de TPU.
    • Cela réserve également quatre nœuds de calcul TPU pour la fonction distante. Ray ne planifiera donc pas d'autres charges de travail TPU sur ce pod Ray.
    • Étant donné que run_on_pod n'utilise que la ressource logique TPU-v6e-16-head, my_function s'exécute également sur le nœud de travail 0, mais dans un processus différent.
  • return ray.get([remote_fn.remote() for _ in range(num_hosts)]): appelle la fonction my_function modifiée un nombre de fois égal au nombre de workers et renvoie les résultats.
  • h = run_on_pod(my_function).remote(): run_on_pod s'exécute de manière asynchrone et ne bloque pas le processus principal.

Autoscaling des tranches TPU

Ray sur les TPU est compatible avec l'autoscaling à la granularité d'une tranche TPU. Vous pouvez activer cette fonctionnalité à l'aide de la fonctionnalité de provisionnement automatique des nœuds GKE (NAP). Vous pouvez exécuter cette fonctionnalité à l'aide de l'autoscaler Ray et de KubeRay. Le type de ressource de tête permet de signaler l'autoscaling à Ray, par exemple TPU-v6e-32-head.