ML-Arbeitslasten mit Ray skalieren

In diesem Dokument finden Sie Details zum Ausführen von ML-Arbeitslasten mit Ray und JAX auf TPUs.

In dieser Anleitung wird davon ausgegangen, dass Sie bereits eine Ray- und TPU-Umgebung eingerichtet haben, einschließlich einer Softwareumgebung mit JAX und anderen zugehörigen Paketen. Folgen Sie der Anleitung unter GKE-Cluster mit TPUs für KubeRay startenGoogle Cloud , um einen Ray-TPU-Cluster zu erstellen. Weitere Informationen zur Verwendung von TPUs mit KubeRay finden Sie unter TPUs mit KubeRay verwenden.

JAX-Arbeitslast auf einer TPU mit einem einzelnen Host ausführen

Das folgende Beispielskript zeigt, wie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit einem einzelnen Host wie einer v6e-4 ausgeführt wird. Wenn Sie eine TPU mit mehreren Hosts haben, reagiert dieses Script aufgrund des Multi-Controller-Ausführungsmodells von JAX nicht mehr. Weitere Informationen zum Ausführen von Ray auf einer TPU mit mehreren Hosts finden Sie unter JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Wenn Sie Ray bisher mit GPUs ausgeführt haben, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:

  • Anstatt num_gpus festzulegen, geben Sie TPU als benutzerdefinierte Ressource an und legen die Anzahl der TPU-Chips fest.
  • Sie geben die TPU anhand der Anzahl der Chips pro Ray-Workerknoten an. Wenn Sie beispielsweise v6e-4 verwenden, wird beim Ausführen einer Remotefunktion mit TPU = 4 der gesamte TPU-Host belegt.
    • Das unterscheidet sich von der üblichen Ausführung von GPUs mit einem Prozess pro Host. Es wird nicht empfohlen, TPU auf eine andere Zahl als 4 festzulegen.
    • Ausnahme: Wenn Sie v6e-8 oder v5litepod-8 für einen einzelnen Host verwenden, sollten Sie diesen Wert auf 8 festlegen.

JAX-Arbeitslast auf einer TPU mit mehreren Hosts ausführen

Das folgende Beispielscript zeigt, wie eine JAX-Funktion in einem Ray-Cluster mit einer TPU mit mehreren Hosts ausgeführt wird. Im Beispielscript wird eine v6e-16 verwendet.

Wenn Sie Ihre Arbeitslast in einem Cluster mit mehreren TPU-Slices ausführen möchten, lesen Sie den Hilfeartikel Einzelne TPU-Slices steuern.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Wenn Sie Ray bisher mit GPUs ausgeführt haben, gibt es einige wichtige Unterschiede bei der Verwendung von TPUs:

  • Ähnlich wie bei PyTorch-Arbeitslasten auf GPUs:
  • Im Gegensatz zu PyTorch-Arbeitslasten auf GPUs hat JAX eine globale Ansicht der verfügbaren Geräte im Cluster.

Multislice-JAX-Arbeitslast ausführen

Mit Multislice können Sie Arbeitslasten ausführen, die mehrere TPU-Slices innerhalb eines einzelnen TPU-Pod oder in mehreren Pods über das Rechenzentrumsnetzwerk umfassen.

Sie können das experimentelle Paket ray-tpu verwenden, um die Interaktionen von Ray mit TPU-Scheiben zu vereinfachen. Installieren Sie ray-tpu mit pip:

pip install ray-tpu

Das folgende Beispielskript zeigt, wie Sie mit dem Paket ray-tpu Multislice-Arbeitslasten mit Ray-Actors oder ‑Tasks ausführen:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

TPU- und Ray-Ressourcen

Ray behandelt TPUs anders als GPUs, um die unterschiedlichen Nutzungsanforderungen zu berücksichtigen. Im folgenden Beispiel gibt es insgesamt neun Ray-Knoten:

  • Der Ray-Leitknoten wird auf einer n1-standard-16-VM ausgeführt.
  • Die Ray-Worker-Knoten werden auf zwei v6e-16 TPUs ausgeführt. Jede TPU besteht aus vier Workern.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Beschreibungen der Felder zur Ressourcennutzung:

  • CPU: Die Gesamtzahl der im Cluster verfügbaren CPUs.
  • TPU: Die Anzahl der TPU-Chips im Cluster.
  • TPU-v6e-16-head: Eine spezielle Kennung für die Ressource, die Worker 0 eines TPU-Slabs entspricht. Das ist wichtig für den Zugriff auf einzelne TPU-Scheiben.
  • memory: Der von Ihrer Anwendung verwendete Worker-Heap-Speicher.
  • object_store_memory: Speicher, der verwendet wird, wenn Ihre Anwendung Objekte im Objektspeicher mit ray.put erstellt und Werte von Remotefunktionen zurückgibt.
  • tpu-group-0 und tpu-group-1: Eindeutige Kennungen für die einzelnen TPU-Scheiben. Das ist wichtig, wenn Jobs auf Segmenten ausgeführt werden sollen. Diese Felder sind auf „4“ festgelegt, da es in einem v6e-16 vier Hosts pro TPU-Speichere gibt.

Einzelne TPU-Slices steuern

Bei Ray und TPUs ist es üblich, mehrere Arbeitslasten im selben TPU-Speicherbereich auszuführen, z. B. bei der Hyperparameter-Abstimmung oder beim Bereitstellen.

TPU-Scheiben müssen bei der Verwendung von Ray sowohl für die Bereitstellung als auch für die Jobplanung besonders berücksichtigt werden.

Arbeitslasten mit einem einzelnen Slice ausführen

Wenn der Ray-Prozess auf TPU-Scheiben gestartet wird (ray start wird ausgeführt), werden Informationen zur Scheibe automatisch erkannt. Dazu gehören beispielsweise die Topologie, die Anzahl der Worker im Slice und ob der Prozess auf Worker 0 ausgeführt wird.

Wenn Sie ray status auf einer TPU v6e-16 mit dem Namen „my-tpu“ ausführen, sieht die Ausgabe in etwa so aus:

worker 0: {"TPU-v6e-16-head": 1, "TPU": 4, "my-tpu": 1"}
worker 1-3: {"TPU": 4, "my-tpu": 1}

"TPU-v6e-16-head" ist das Ressourcenlabel für Worker 0 des Segments. "TPU": 4 gibt an, dass jeder Worker 4 Chips hat. "my-tpu" ist der Name der TPU. Sie können diese Werte verwenden, um eine Arbeitslast auf TPUs innerhalb desselben Segments auszuführen, wie im folgenden Beispiel.

Angenommen, Sie möchten die folgende Funktion auf allen Workern in einem Sliver ausführen:

@ray.remote()
def my_function():
    return jax.device_count()

Sie müssen Worker 0 des Slices als Ziel auswählen und ihm dann mitteilen, wie er my_function an alle Worker im Slice senden soll:

@ray.remote(resources={"TPU-v6e-16-head": 1})
def run_on_pod(remote_fn):
    tpu_name = ray.util.accelerators.tpu.get_current_pod_name()  # -> returns my-tpu
    num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count() # -> returns 4
    remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}) # required resources are {"my-tpu": 1, "TPU": 4}
    return ray.get([remote_fn.remote() for _ in range(num_hosts)])

h = run_on_pod(my_function).remote() # -> returns a single remote handle
ray.get(h) # -> returns ["16"] * 4

Im Beispiel werden die folgenden Schritte ausgeführt:

  • @ray.remote(resources={"TPU-v6e-16-head": 1}): Die Funktion run_on_pod wird auf einem Worker mit dem Ressourcenlabel TPU-v6e-16-head ausgeführt, das auf einen beliebigen Worker 0 ausgerichtet ist.
  • tpu_name = ray.util.accelerators.tpu.get_current_pod_name(): Ruft den TPU-Namen ab.
  • num_hosts = ray.util.accelerators.tpu.get_current_pod_worker_count(): Die Anzahl der Worker im Ausschnitt abrufen.
  • remote_fn = remote_fn.options(resources={tpu_name: 1, "TPU": 4}): Fügen Sie der Funktion my_function das Ressourcenlabel mit dem TPU-Namen und der "TPU": 4-Ressourcenanforderungen hinzu.
    • Da jeder Worker im TPU-Speicherbereich ein benutzerdefiniertes Ressourcenlabel für den Speicherbereich hat, in dem er sich befindet, plant Ray die Arbeitslast nur auf den Workern innerhalb desselben TPU-Speicherbereichs.
    • Dadurch werden auch vier TPU-Worker für die Remote-Funktion reserviert, sodass Ray keine anderen TPU-Arbeitslasten auf diesem Ray-Pod plant.
    • Da run_on_pod nur die logische Ressource TPU-v6e-16-head verwendet, wird my_function auch auf Worker 0 ausgeführt, aber in einem anderen Prozess.
  • return ray.get([remote_fn.remote() for _ in range(num_hosts)]): Die geänderte my_function-Funktion wird so oft aufgerufen wie es Arbeitskräfte gibt und die Ergebnisse werden zurückgegeben.
  • h = run_on_pod(my_function).remote(): run_on_pod wird asynchron ausgeführt und blockiert den Hauptprozess nicht.

Autoscaling von TPU-Slices

Ray auf TPUs unterstützt das Autoscaling auf Ebene eines TPU-Slabs. Sie können diese Funktion mit der automatischen Bereitstellung von GKE-Knoten (GKE Node Auto Provisioning, NAP) aktivieren. Sie können diese Funktion mit dem Ray-Autoscaler und KubeRay ausführen. Mit dem Ressourcentyp „head“ wird Ray das Autoscaling signalisiert, z. B. TPU-v6e-32-head.