Escala las cargas de trabajo de AA con Ray
En este documento, se proporcionan detalles sobre cómo ejecutar cargas de trabajo de aprendizaje automático (AA) con Ray y JAX en TPU.
En estas instrucciones, se supone que ya tienes configurado un entorno de Ray y TPU, incluido un entorno de software que incluye JAX y otros paquetes relacionados. Para crear un clúster de Ray TPU, sigue las instrucciones que se indican en Cómo iniciar unGoogle Cloud clúster de GKE con TPU para KubeRay. Para obtener más información sobre el uso de TPU con KubeRay, consulta Cómo usar TPU con KubeRay.
Ejecuta una carga de trabajo de JAX en una TPU de host único
En la siguiente secuencia de comandos de ejemplo, se muestra cómo ejecutar una función de JAX en un clúster de Ray con una TPU de un solo host, como una v6e-4. Si tienes una TPU de varios hosts, esta secuencia de comandos deja de responder debido al modelo de ejecución de varios controladores de JAX. Para obtener más información sobre cómo ejecutar Ray en una TPU de varios hosts, consulta Cómo ejecutar una carga de trabajo de JAX en una TPU de varios hosts.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
Si estás acostumbrado a ejecutar Ray con GPUs, existen algunas diferencias clave cuando se usan TPU:
- En lugar de configurar
num_gpus
, especificasTPU
como un recurso personalizado y estableces la cantidad de chips TPU. - Especificas la TPU con la cantidad de chips por nodo trabajador de Ray. Por ejemplo, si usas una v6e-4, ejecutar una función remota con
TPU
establecida en 4 consume todo el host de TPU.- Esto es diferente de la forma en que se ejecutan las GPU, con un proceso por host.
No se recomienda establecer
TPU
en un número que no sea 4. - Excepción: Si tienes un
v6e-8
ov5litepod-8
de un solo host, debes establecer este valor en 8.
- Esto es diferente de la forma en que se ejecutan las GPU, con un proceso por host.
No se recomienda establecer
Ejecuta una carga de trabajo de JAX en una TPU de varios hosts
En la siguiente secuencia de comandos de ejemplo, se muestra cómo ejecutar una función de JAX en un clúster de Ray con una TPU de varios hosts. La secuencia de comandos de ejemplo usa una v6e-16.
Si deseas ejecutar tu carga de trabajo en un clúster con varias porciones de TPU, consulta Controla porciones de TPU individuales.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
Si estás acostumbrado a ejecutar Ray con GPUs, existen algunas diferencias clave cuando se usan TPU:
- Similar a las cargas de trabajo de PyTorch en GPUs:
- Las cargas de trabajo de JAX en TPUs se ejecutan en un modo de varios controladores, un solo programa y varios datos (SPMD).
- El framework de aprendizaje automático controla los colectivos entre dispositivos.
- A diferencia de las cargas de trabajo de PyTorch en GPUs, JAX tiene una vista global de los dispositivos disponibles en el clúster.
Ejecuta una carga de trabajo de JAX de Multislice
Multislice te permite ejecutar cargas de trabajo que abarcan varias porciones de TPU dentro de un solo pod de TPU o en varios pods a través de la red del centro de datos.
Para mayor comodidad, puedes usar el paquete experimental ray-tpu
para simplificar las interacciones de Ray con las porciones de TPU. Instala ray-tpu
con pip
:
pip install ray-tpu
En la siguiente secuencia de comandos de ejemplo, se muestra cómo usar el paquete ray-tpu
para ejecutar cargas de trabajo de Multislice con actores o tareas de Ray:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
Recursos de TPU y Ray
Ray trata las TPU de manera diferente a las GPUs para adaptarse a la diferencia de uso. En el siguiente ejemplo, hay nueve nodos de Ray en total:
- El nodo principal de Ray se ejecuta en una VM
n1-standard-16
. - Los nodos de trabajo de Ray se ejecutan en dos TPU
v6e-16
. Cada TPU constituye cuatro trabajadores.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Descripciones de los campos de uso de recursos:
CPU
: Es la cantidad total de CPUs disponibles en el clúster.TPU
: Es la cantidad de chips TPU en el clúster.TPU-v6e-16-head
: Es un identificador especial para el recurso que corresponde al trabajador 0 de una porción de TPU. Esto es importante para acceder a las porciones individuales de TPU.memory
: Es la memoria del montón de trabajo que usa tu aplicación.object_store_memory
: Es la memoria que se usa cuando tu aplicación crea objetos en el almacén de objetos conray.put
y cuando muestra valores de funciones remotas.tpu-group-0
ytpu-group-1
: Son identificadores únicos para las porciones individuales de TPU. Esto es importante para ejecutar trabajos en rebanadas. Estos campos se establecen en 4 porque hay 4 hosts por porción de TPU en una v6e-16.
Controla porciones de TPU individuales
Una práctica común con Ray y las TPU es ejecutar varias cargas de trabajo dentro de la misma porción de TPU, por ejemplo, en el ajuste de hiperparámetros o la entrega.
Las porciones de TPU requieren una consideración especial cuando se usa Ray para el aprovisionamiento y la programación de trabajos.
Ejecuta cargas de trabajo de una sola porción
Cuando se inicia el proceso de Ray en las porciones de TPU (se ejecuta ray start
), el proceso detecta automáticamente la información sobre la porción. Por ejemplo, la topología, la cantidad de trabajadores en la porción y si el proceso se ejecuta en el trabajador 0.
Cuando ejecutas ray status
en una TPU v6e-16 con el nombre "my-tpu", el resultado se ve similar al siguiente:
worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}
"TPU-v6e-16-head"
es la etiqueta de recurso del trabajador 0 de la porción.
"TPU": 4
indica que cada trabajador tiene 4 chips. "my-tpu"
es el nombre de la TPU. Puedes usar estos valores para ejecutar una carga de trabajo en TPU dentro de la misma porción, como en el siguiente ejemplo.
Supongamos que deseas ejecutar la siguiente función en todos los trabajadores de una porción:
@ray.remote()
def my_function():
return jax.device_count()
Debes orientar el trabajador 0 de la porción y, luego, decirle al trabajador 0 cómo transmitir my_function
a todos los trabajadores de la porción:
@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> returns my-tpu
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4
En el ejemplo, se realizan los siguientes pasos:
@ray.remote(resources={"TPU-v6e-16-head": 1})
: La funciónrun_on_pod
se ejecuta en un trabajador que tiene la etiqueta de recursoTPU-v6e-16-head
, que se orienta a cualquier trabajador arbitrario 0.tpu_name = ray.util.accelerators.tpu.get_current_pod_name()
: Obtén el nombre de la TPU.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
: Obtén la cantidad de trabajadores en la porción.remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4})
: Agrega la etiqueta de recurso que contiene el nombre de la TPU y el requisito de recursos"TPU": 4
a la funciónmy_function
.- Debido a que cada trabajador en la porción de TPU tiene una etiqueta de recurso personalizada para la porción en la que se encuentra, Ray solo programará la carga de trabajo en los trabajadores dentro de la misma porción de TPU.
- Esto también reserva 4 trabajadores de TPU para la función remota, por lo que Ray no programará otras cargas de trabajo de TPU en ese Pod de Ray.
- Debido a que
run_on_pod
solo usa el recurso lógicoTPU-v6e-16-head
,my_function
también se ejecutará en el trabajador 0, pero en un proceso diferente.
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
: Invoca la funciónmy_function
modificada una cantidad de veces igual a la cantidad de trabajadores y muestra los resultados.h = run_on_pod(my_function).remote()
:run_on_pod
se ejecutará de forma asíncrona y no bloqueará el proceso principal.
Ajuste de escala automático de porciones de TPU
Ray en TPU admite el ajuste de escala automático en el nivel de detalle de una porción de TPU. Puedes habilitar esta función con la función de aprovisionamiento automático de nodos de GKE (NAP). Puedes ejecutar esta función con el escalador automático de Ray y KubeRay. El tipo de recurso principal se usa para indicar el ajuste de escala automático a Ray, por ejemplo, TPU-v6e-32-head
.