Escalonar cargas de trabalho de ML usando o Ray
Este documento fornece detalhes sobre como executar cargas de trabalho de machine learning (ML) com o Ray e o JAX em TPUs.
Estas instruções pressupõem que você já tenha configurado um ambiente Ray e TPU, incluindo um ambiente de software que inclui JAX e outros pacotes relacionados. Para criar um cluster do Ray TPU, siga as instruções em Iniciar Google Cloud cluster do GKE com TPUs para KubeRay. Para mais informações sobre o uso de TPUs com o KubeRay, consulte Usar TPUs com o KubeRay.
Executar uma carga de trabalho do JAX em uma TPU de host único
O exemplo de script abaixo demonstra como executar uma função JAX em um cluster do Ray com um TPU de host único, como um v6e-4. Se você tiver uma TPU de vários hosts, esse script vai parar de responder devido ao modelo de execução de vários controladores do JAX. Para mais informações sobre como executar o Ray em uma TPU de vários hosts, consulte Executar uma carga de trabalho do JAX em uma TPU de vários hosts.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
Se você está acostumado a executar o Ray com GPUs, há algumas diferenças importantes ao usar TPUs:
- Em vez de definir
num_gpus
, especifiqueTPU
como um recurso personalizado e defina o número de chips de TPU. - Especifique a TPU usando o número de chips por nó de worker do Ray. Por
exemplo, se você estiver usando uma v6e-4, a execução de uma função remota com
TPU
definido como 4 consome todo o host TPU.- Isso é diferente da forma como as GPUs normalmente são executadas, com um processo por host.
Não é recomendável definir
TPU
como um número diferente de 4. - Exceção: se você tiver um
v6e-8
ouv5litepod-8
de host único, defina esse valor como 8.
- Isso é diferente da forma como as GPUs normalmente são executadas, com um processo por host.
Não é recomendável definir
Executar uma carga de trabalho do JAX em uma TPU de vários hosts
O exemplo de script abaixo demonstra como executar uma função JAX em um cluster do Ray com uma TPU de vários hosts. O script de exemplo usa uma v6e-16.
Se você quiser executar a carga de trabalho em um cluster com várias frações de TPU, consulte Controlar frações de TPU individuais.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
Se você está acostumado a executar o Ray com GPUs, há algumas diferenças importantes ao usar TPUs:
- Semelhante às cargas de trabalho do PyTorch em GPUs:
- As cargas de trabalho do JAX em TPUs são executadas de uma forma de vários controladores, único programa, vários dados (SPMD, na sigla em inglês).
- Os coletivos entre dispositivos são processados pelo framework de machine learning.
- Ao contrário das cargas de trabalho do PyTorch em GPUs, o JAX tem uma visão global dos dispositivos disponíveis no cluster.
Executar uma carga de trabalho JAX Multislice
O Multislice permite executar cargas de trabalho que abrangem várias frações de TPU em um único pod de TPU ou em vários pods na rede do data center.
Para sua conveniência, use o pacote experimental
ray-tpu
para simplificar
as interações do Ray com fatias de TPU. Instale ray-tpu
usando pip
:
pip install ray-tpu
O exemplo de script a seguir mostra como usar o pacote ray-tpu
para executar
cargas de trabalho de várias fatias usando atores ou tarefas do Ray:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
Recursos de TPU e Ray
O Ray trata as TPUs de maneira diferente das GPUs para acomodar a diferença no uso. No exemplo a seguir, há nove nós Ray no total:
- O nó principal do Ray está em execução em uma VM
n1-standard-16
. - Os nós de trabalho do Ray estão sendo executados em duas TPUs
v6e-16
. Cada TPU constitui quatro workers.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Descrições dos campos de uso de recursos:
CPU
: o número total de CPUs disponíveis no cluster.TPU
: o número de chips de TPU no cluster.TPU-v6e-16-head
: um identificador especial para o recurso que corresponde ao worker 0 de uma fração de TPU. Isso é importante para acessar fatias de TPU individuais.memory
: memória heap do worker usada pelo aplicativo.object_store_memory
: memória usada quando o aplicativo cria objetos no armazenamento de objetos usandoray.put
e quando ele retorna valores de funções remotas.tpu-group-0
etpu-group-1
: identificadores exclusivos para os segmentos de TPU individuais. Isso é importante para executar trabalhos em fatias. Esses campos são definidos como 4 porque há 4 hosts por fatia de TPU em uma v6e-16.
Controlar frações de TPU individuais
Uma prática comum com o Ray e as TPUs é executar vários workloads na mesma fração de TPU, por exemplo, no ajuste ou fornecimento de hiperparâmetros.
As fatias de TPU exigem consideração especial ao usar o Ray para provisionamento e programação de jobs.
Executar cargas de trabalho de fatia única
Quando o processo Ray é iniciado em fatias de TPU (executando ray start
), ele
detecta automaticamente informações sobre a fatia. Por exemplo, a topologia, o número de workers na fração e se o processo está sendo executado no worker 0.
Quando você executa ray status
em um TPU v6e-16 com o nome "my-tpu", a saída
é semelhante a esta:
worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}
"TPU-v6e-16-head"
é o rótulo de recurso do worker 0 da fatia.
"TPU": 4
indica que cada worker
tem quatro chips. "my-tpu"
é o nome da TPU. É possível usar esses valores para executar
uma carga de trabalho em TPUs na mesma fatia, como no exemplo a seguir.
Suponha que você queira executar a função a seguir em todos os workers de uma fatia:
@ray.remote()
def my_function():
return jax.device_count()
Você precisa segmentar o worker 0 da fração e informar a ele como transmitir
my_function
para todos os workers na fração:
@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
tpu_name = ray.util.accelerators.tpu.get_current_pod_name() # -> returns my-tpu
num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4
O exemplo executa as seguintes etapas:
@ray.remote(resources={"TPU-v6e-16-head": 1})
: a funçãorun_on_pod
é executada em um worker que tem o rótulo de recursoTPU-v6e-16-head
, que é direcionado a qualquer worker arbitrário 0.tpu_name = ray.util.accelerators.tpu.get_current_pod_name()
: extrai o nome da TPU.num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count()
: extrai o número de workers na fatia.remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4})
: adicione o rótulo de recurso que contém o nome da TPU e o requisito de recurso"TPU": 4
à funçãomy_function
.- Como cada worker na fração da TPU tem um rótulo de recurso personalizado para a fração em que está, o Ray só vai programar a carga de trabalho nos workers dentro da mesma fração da TPU.
- Isso também reserva quatro workers de TPU para a função remota, para que o Ray não programe outras cargas de trabalho de TPU nesse pod do Ray.
- Como
run_on_pod
usa apenas o recurso lógicoTPU-v6e-16-head
,my_function
também será executado no worker 0, mas em um processo diferente.
return ray.get([remote_fn.remote() for _ in range(num_hosts)])
: invoca a funçãomy_function
modificada um número de vezes igual ao número de workers e retorna os resultados.h = run_on_pod(my_function).remote()
:run_on_pod
será executado de maneira assíncrona e não bloqueará o processo principal.
Escalonamento automático de fatias de TPU
O Ray em TPUs oferece suporte ao escalonamento automático na granularidade de uma fração de TPU. É possível
ativar esse recurso usando o provisionamento automático de nós do GKE
(NAP). É possível
executar esse recurso usando o escalonador automático do Ray e o KubeRay. O tipo de recurso principal
é usado para sinalizar o escalonamento automático para o Ray, por exemplo, TPU-v6e-32-head
.