Escala las cargas de trabajo de AA con Ray

En este documento, se proporcionan detalles sobre cómo ejecutar cargas de trabajo de aprendizaje automático (AA) con Ray y JAX en TPU. Hay dos modos diferentes para usar TPU con Ray: el modo centrado en el dispositivo (PyTorch/XLA) y el modo centrado en el host (JAX).

En este documento, se supone que ya tienes configurado un entorno de TPU. Para obtener más información, consulta los siguientes recursos:

Modo centrado en el dispositivo (PyTorch/XLA)

El modo centrado en el dispositivo conserva gran parte del estilo programático de PyTorch clásico. En este modo, agregas un nuevo tipo de dispositivo XLA, que funciona como cualquier otro dispositivo PyTorch. Cada proceso individual interactúa con un dispositivo XLA.

Este modo es ideal si ya conoces PyTorch con GPUs y quieres usar abstracciones de programación similares.

En las siguientes secciones, se describe cómo ejecutar una carga de trabajo de PyTorch/XLA en uno o más dispositivos sin usar Ray y, luego, cómo ejecutar la misma carga de trabajo en varios hosts con Ray.

Crear una TPU

  1. Crea variables de entorno para los parámetros de creación de TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export VERSION=v2-alpha-tpuv5

    Descripciones de las variables de entorno

    TPU_NAME
    El nombre de tu nueva Cloud TPU.
    ZONE
    Es la zona en la que debes crear la Cloud TPU.
    accelerator-type
    El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información, consulta Versiones de TPU.
    version
    La versión del software de TPU que quieres usar. Para obtener más información, consulta Imágenes de VM de TPU.
  2. Usa el siguiente comando para crear una VM de TPU v5p con 8 núcleos:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE  \
        --version=$VERSION
  3. Conéctate a la VM de TPU con el siguiente comando:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Si usas GKE, consulta la guía de KubeRay en GKE para obtener información sobre la configuración.

Requisitos de instalación

Ejecuta los siguientes comandos en tu VM de TPU para instalar las dependencias requeridas:

  1. Guarda lo siguiente en un archivo, por ejemplo, requirements.txt:

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. Ejecuta el siguiente comando para instalar las dependencias requeridas:

    pip install -r requirements.txt
    

Si ejecutas tu carga de trabajo en GKE, te recomendamos que crees un Dockerfile que instale las dependencias requeridas. Para ver un ejemplo, consulta Ejecuta tu carga de trabajo en nodos de porción de TPU en la documentación de GKE.

Ejecuta una carga de trabajo de PyTorch/XLA en un solo dispositivo

En el siguiente ejemplo, se muestra cómo crear un tensor de XLA en un solo dispositivo, que es un chip de TPU. Esto es similar a la forma en que PyTorch controla otros tipos de dispositivos.

  1. Guarda el siguiente fragmento de código en un archivo, por ejemplo, workload.py:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    La sentencia de importación import torch_xla inicializa PyTorch/XLA, y la función xm.xla_device() muestra el dispositivo XLA actual, un chip TPU.

  2. Establece la variable de entorno PJRT_DEVICE como TPU:

    export PJRT_DEVICE=TPU
    
  3. Ejecuta la secuencia de comandos:

    python workload.py
    

    El resultado es similar al siguiente. Asegúrate de que el resultado indique que se encontró el dispositivo XLA.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

Ejecuta PyTorch/XLA en varios dispositivos

  1. Actualiza el fragmento de código de la sección anterior para que se ejecute en varios dispositivos:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. Ejecuta la secuencia de comandos:

    python workload.py
    

    Si ejecutas el fragmento de código en una TPU v5p-8, el resultado será similar al siguiente:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() toma dos argumentos, una función y una lista de parámetros. Crea un proceso para cada dispositivo XLA disponible y llama a la función especificada en los argumentos. En este ejemplo, hay 4 dispositivos de TPU disponibles, por lo que torch_xla.launch() crea 4 procesos y llama a _mp_fn() en cada dispositivo. Cada proceso solo tiene acceso a un dispositivo, por lo que cada dispositivo tiene el índice 0 y se imprime xla:0 para todos los procesos.

Ejecuta PyTorch/XLA en varios hosts con Ray

En las siguientes secciones, se muestra cómo ejecutar el mismo fragmento de código en una porción de TPU de varios hosts más grande. Para obtener más información sobre la arquitectura de TPU de varios hosts, consulta Arquitectura del sistema.

En este ejemplo, configurarás Ray de forma manual. Si ya conoces la configuración de Ray, puedes pasar a la última sección, Ejecuta una carga de trabajo de Ray. Para obtener más información sobre cómo configurar Ray para un entorno de producción, consulta los siguientes recursos:

Crea una VM de TPU de varios hosts

  1. Crea variables de entorno para los parámetros de creación de TPU:

    export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE_MULTIHOST=v5p-16
    export VERSION=v2-alpha-tpuv5
  2. Crea una TPU v5p de varios hosts con 2 hosts (una v5p-16, con 4 chips de TPU en cada host) con el siguiente comando:

    gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
        --version=$VERSION

Configura Ray

Una TPU v5p-16 tiene 2 hosts de TPU, cada uno con 4 chips TPU. En este ejemplo, iniciarás el nodo principal de Ray en un host y agregarás el segundo host como nodo trabajador al clúster de Ray.

  1. Conéctate al primer host con SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
  2. Instala las dependencias con el mismo archivo de requisitos que en la sección Instalar requisitos:

    pip install -r requirements.txt
    
  3. Inicia el proceso de Ray:

    ray start --head --port=6379
    

    El resultado es similar al siguiente:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    Este host de TPU ahora es el nodo principal de Ray. Anota las líneas que muestran cómo agregar otro nodo al clúster de Ray, de manera similar a la siguiente:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    Usarás este comando en un paso posterior.

  4. Verifica el estado del clúster de Ray:

    ray status
    

    El resultado es similar al siguiente:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    El clúster solo contiene 4 TPU (0.0/4.0 TPU) porque solo agregaste el nodo principal hasta el momento.

Ahora que el nodo principal se está ejecutando, puedes agregar el segundo host al clúster.

  1. Conéctate al segundo host con SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
  2. Instala las dependencias con el mismo archivo de requisitos que en la sección Instalar requisitos:

    pip install -r requirements.txt
    
  3. Inicia el proceso de Ray. Usa el comando del resultado del comando ray start para agregar este nodo al clúster de Ray existente. Asegúrate de reemplazar la dirección IP y el puerto en el siguiente comando:

    ray start --address='10.130.0.76:6379'

    El resultado es similar al siguiente:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  4. Vuelve a verificar el estado de Ray:

    ray status
    

    El resultado es similar al siguiente:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    El segundo host de TPU ahora es un nodo en el clúster. La lista de recursos disponibles ahora muestra 8 TPU (0.0/8.0 TPU).

Ejecuta una carga de trabajo de Ray

  1. Actualiza el fragmento de código para que se ejecute en el clúster de Ray:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster
    NUM_OF_HOSTS = 2
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] 
    ray.get(tasks)
    
    ray.shutdown()
    
  2. Ejecuta la secuencia de comandos en el nodo principal de Ray. Reemplaza ray-workload.py por la ruta de acceso a tu secuencia de comandos.

    python ray-workload.py

    El resultado es similar a este:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    El resultado indica que se llamó correctamente a la función en cada dispositivo XLA (8 dispositivos en este ejemplo) en la porción de TPU de varios hosts.

Modo centrado en el host (JAX)

En las siguientes secciones, se describe el modo centrado en el host con JAX. JAX utiliza un paradigma de programación funcional y admite semánticas de un solo programa y varios datos (SPMD) de nivel superior. En lugar de que cada proceso interactúe con un solo dispositivo XLA, el código JAX está diseñado para funcionar en varios dispositivos en un solo host de forma simultánea.

JAX está diseñado para la computación de alto rendimiento y puede usar TPU de manera eficiente para el entrenamiento y la inferencia a gran escala. Este modo es ideal si conoces los conceptos de programación funcional para que puedas aprovechar todo el potencial de JAX.

En estas instrucciones, se supone que ya tienes configurado un entorno de Ray y TPU, incluido un entorno de software que incluye JAX y otros paquetes relacionados. Para crear un clúster de Ray TPU, sigue las instrucciones que se indican en Cómo iniciar unGoogle Cloud clúster de GKE con TPU para KubeRay. Para obtener más información sobre el uso de TPU con KubeRay, consulta Cómo usar TPU con KubeRay.

Ejecuta una carga de trabajo de JAX en una TPU de host único

En la siguiente secuencia de comandos de ejemplo, se muestra cómo ejecutar una función de JAX en un clúster de Ray con una TPU de un solo host, como una v6e-4. Si tienes una TPU de varios hosts, esta secuencia de comandos deja de responder debido al modelo de ejecución de varios controladores de JAX. Para obtener más información sobre cómo ejecutar Ray en una TPU de varios hosts, consulta Cómo ejecutar una carga de trabajo de JAX en una TPU de varios hosts.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Si estás acostumbrado a ejecutar Ray con GPUs, hay algunas diferencias clave cuando se usan TPU:

  • En lugar de configurar num_gpus, especificas TPU como un recurso personalizado y estableces la cantidad de chips TPU.
  • Especificas la TPU con la cantidad de chips por nodo de trabajo de Ray. Por ejemplo, si usas una v6e-4, ejecutar una función remota con TPU establecida en 4 consume todo el host de TPU.
    • Esto es diferente de la forma en que se ejecutan las GPU, con un proceso por host. No se recomienda establecer TPU en un número que no sea 4.
    • Excepción: Si tienes un v6e-8 o v5litepod-8 de un solo host, debes establecer este valor en 8.

Ejecuta una carga de trabajo de JAX en una TPU de varios hosts

En la siguiente secuencia de comandos de ejemplo, se muestra cómo ejecutar una función de JAX en un clúster de Ray con una TPU de varios hosts. La secuencia de comandos de ejemplo usa una v6e-16.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Si estás acostumbrado a ejecutar Ray con GPUs, hay algunas diferencias clave cuando se usan TPU:

  • Similar a las cargas de trabajo de PyTorch en GPUs:
  • A diferencia de las cargas de trabajo de PyTorch en GPUs, JAX tiene una vista global de los dispositivos disponibles en el clúster.

Ejecuta una carga de trabajo de JAX de Multislice

Multislice te permite ejecutar cargas de trabajo que abarcan varias porciones de TPU dentro de un solo pod de TPU o en varios pods a través de la red del centro de datos.

Puedes usar el paquete ray-tpu para simplificar las interacciones de Ray con las porciones de TPU. Instala ray-tpu con pip:

pip install ray-tpu

En la siguiente secuencia de comandos de ejemplo, se muestra cómo usar el paquete ray-tpu para ejecutar cargas de trabajo de Multislice con actores o tareas de Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Organiza cargas de trabajo con Ray y MaxText

En esta sección, se describe cómo usar Ray para organizar cargas de trabajo con MaxText, una biblioteca de código abierto escalable y de alto rendimiento para entrenar LLM con JAX y XLA.

MaxText contiene una secuencia de comandos de entrenamiento, train.py, que se debe ejecutar en cada host de TPU. Esto es similar a otras cargas de trabajo de aprendizaje automático de SPMD. Puedes lograr esto con el paquete ray-tpu y crear un wrapper alrededor de la función principal train.py. En los siguientes pasos, se muestra cómo usar el paquete ray-tpu para ejecutar MaxText en una TPU v4-16.

  1. Establece variables de entorno para los parámetros de creación de TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=ZONE
    export ACCELERATOR_TYPE=v6e-16
    export VERSION=v2-alpha-tpuv6e
  2. Crea una TPU v6e-16:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION
  3. Clona el repositorio de MaxText en todos los trabajadores de TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="git clone https://github.com/AI-Hypercomputer/maxtext"
  4. Instala los requisitos de MaxText en todos los trabajadores de TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install -r maxtext/requirements.txt"
  5. Instala el paquete ray-tpu en todos los trabajadores de TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray-tpu"
  6. Conéctate al trabajador 0 con SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=0
  7. Guarda la siguiente secuencia de comandos en un archivo llamado ray_trainer.py en el directorio ~/maxtext/MaxText. Esta secuencia de comandos usa el paquete ray-tpu y crea un wrapper alrededor de la función principal train.py de MaxText.

    import ray
    import ray_tpu
    from train import main as maxtext_main
    
    import logging
    from typing import Sequence
    from absl import app
    
    # Default env vars that run on all TPU VMs.
    MACHINE_ENV_VARS = {
        "ENABLE_PJRT_COMPATIBILITY": "true",
        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
    }
    
    def setup_loggers():
        """Sets up loggers for Ray."""
        logging.basicConfig(level=logging.INFO)
    
    @ray_tpu.remote(
        topology={"v4-16": 1},
    )
    def run_maxtext_train(argv: Sequence[str]):
        maxtext_main(argv=argv)
    
    def main(argv: Sequence[str]):
        ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
    
        logging.info(f"argv: {argv}")
    
        try:
            ray.get(run_maxtext_train(argv=argv))
        except Exception as e:
            logging.error("Caught error during training: %s", e)
            logging.error("Shutting down...")
            ray.shutdown()
            raise e
    
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        app.run(main)
    
  8. Ejecuta la secuencia de comandos mediante el siguiente comando:

        python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
            base_output_directory=/tmp/maxtext \
            dataset_type=synthetic \
            per_device_batch_size=2 \
            max_target_length=8192 \
            model_name=default \
            steps=100 \
            run_name=test
    

    El resultado es similar al siguiente:

    (run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=78967, ip=10.130.0.11)
    (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
    (run_maxtext_train pid=80538, ip=10.130.0.12)
    (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
    (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
    (run_maxtext_train pid=80538, ip=10.130.0.12)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    
    ...
    
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
    I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
    

Recursos de TPU y Ray

Ray trata las TPU de manera diferente a las GPUs para adaptarse a la diferencia de uso. En el siguiente ejemplo, hay nueve nodos de Ray en total:

  • El nodo principal de Ray se ejecuta en una VM n1-standard-16.
  • Los nodos de trabajo de Ray se ejecutan en dos TPU v6e-16. Cada TPU constituye cuatro trabajadores.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Descripciones de los campos de uso de recursos:

  • CPU: Es la cantidad total de CPUs disponibles en el clúster.
  • TPU: Es la cantidad de chips TPU en el clúster.
  • TPU-v6e-16-head: Es un identificador especial para el recurso que corresponde al trabajador 0 de una porción de TPU. Esto es importante para acceder a las porciones individuales de TPU.
  • memory: Es la memoria del montón de trabajo que usa tu aplicación.
  • object_store_memory: Es la memoria que se usa cuando tu aplicación crea objetos en el almacén de objetos con ray.put y cuando muestra valores de funciones remotas.
  • tpu-group-0 y tpu-group-1: Son identificadores únicos para las porciones individuales de TPU. Esto es importante para ejecutar trabajos en rebanadas. Estos campos se establecen en 4 porque hay 4 hosts por porción de TPU en una v6e-16.