Évoluer les charges de travail de ML à l'aide de Ray

Ce document explique comment exécuter des charges de travail de machine learning (ML) avec Ray et JAX sur des TPU. Il existe deux modes d'utilisation des TPU avec Ray : le mode axé sur l'appareil (PyTorch/XLA) et le mode axé sur l'hôte (JAX).

Ce document suppose que vous avez déjà configuré un environnement TPU. Pour en savoir plus, consultez les ressources suivantes:

Mode axé sur l'appareil (PyTorch/XLA)

Le mode axé sur l'appareil conserve une grande partie du style programmatique de PyTorch classique. Dans ce mode, vous ajoutez un type d'appareil XLA, qui fonctionne comme n'importe quel autre appareil PyTorch. Chaque processus interagit avec un seul appareil XLA.

Ce mode est idéal si vous connaissez déjà PyTorch avec des GPU et que vous souhaitez utiliser des abstractions de codage similaires.

Les sections suivantes décrivent comment exécuter une charge de travail PyTorch/XLA sur un ou plusieurs appareils sans utiliser Ray, puis comment exécuter la même charge de travail sur plusieurs hôtes à l'aide de Ray.

Créer un TPU

  1. Créez des variables d'environnement pour les paramètres de création de TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export VERSION=v2-alpha-tpuv5

    Descriptions des variables d'environnement

    TPU_NAME
    Nom de votre nouveau Cloud TPU.
    ZONE
    Zone dans laquelle créer votre Cloud TPU.
    accelerator-type
    Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus, consultez la section Versions de TPU.
    version
    Version logicielle du TPU que vous souhaitez utiliser. Pour en savoir plus, consultez la section Images de VM TPU.
  2. Utilisez la commande suivante pour créer une VM TPU v5p avec huit cœurs:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE  \
        --version=$VERSION
  3. Connectez-vous à la VM TPU à l'aide de la commande suivante:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Si vous utilisez GKE, consultez le guide KubeRay sur GKE pour en savoir plus sur la configuration.

Configuration requise

Exécutez les commandes suivantes sur votre VM TPU pour installer les dépendances requises:

  1. Enregistrez le code suivant dans un fichier, par exemple requirements.txt:

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. Exécutez la commande suivante pour installer les dépendances requises:

    pip install -r requirements.txt
    

Si vous exécutez votre charge de travail sur GKE, nous vous recommandons de créer un Dockerfile qui installe les dépendances requises. Pour obtenir un exemple, consultez la section Exécuter votre charge de travail sur des nœuds de tranche TPU dans la documentation GKE.

Exécuter une charge de travail PyTorch/XLA sur un seul appareil

L'exemple suivant montre comment créer un tenseur XLA sur un seul appareil, qui est une puce TPU. Cela ressemble à la façon dont PyTorch gère les autres types d'appareils.

  1. Enregistrez l'extrait de code suivant dans un fichier, par exemple workload.py:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    L'instruction d'importation import torch_xla initialise PyTorch/XLA, et la fonction xm.xla_device() renvoie l'appareil XLA actuel, une puce TPU.

  2. Définissez la variable d'environnement PJRT_DEVICE sur TPU:

    export PJRT_DEVICE=TPU
    
  3. Exécutez le script :

    python workload.py
    

    La sortie doit ressembler à ce qui suit. Assurez-vous que la sortie indique que l'appareil XLA a été détecté.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

Exécuter PyTorch/XLA sur plusieurs appareils

  1. Modifiez l'extrait de code de la section précédente pour qu'il s'exécute sur plusieurs appareils:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. Exécutez le script :

    python workload.py
    

    Si vous exécutez l'extrait de code sur un TPU v5p-8, le résultat ressemble à ceci:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() accepte deux arguments, une fonction et une liste de paramètres. Il crée un processus pour chaque appareil XLA disponible et appelle la fonction spécifiée dans les arguments. Dans cet exemple, quatre appareils TPU sont disponibles. torch_xla.launch() crée donc quatre processus et appelle _mp_fn() sur chaque appareil. Chaque processus n'a accès qu'à un seul appareil. Par conséquent, chaque appareil a l'indice 0, et xla:0 est imprimé pour tous les processus.

Exécuter PyTorch/XLA sur plusieurs hôtes avec Ray

Les sections suivantes montrent comment exécuter le même extrait de code sur une tranche TPU multi-hôte plus importante. Pour en savoir plus sur l'architecture TPU multi-hôte, consultez la section Architecture du système.

Dans cet exemple, vous allez configurer manuellement Ray. Si vous savez déjà configurer Ray, vous pouvez passer à la dernière section, Exécuter une charge de travail Ray. Pour en savoir plus sur la configuration de Ray pour un environnement de production, consultez les ressources suivantes:

Créer une VM TPU multi-hôte

  1. Créez des variables d'environnement pour les paramètres de création de TPU:

    export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE_MULTIHOST=v5p-16
    export VERSION=v2-alpha-tpuv5
  2. Créez un TPU v5p multi-hôte avec deux hôtes (un v5p-16, avec quatre puces TPU sur chaque hôte) à l'aide de la commande suivante:

    gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
        --version=$VERSION

Configurer Ray

Un TPU v5p-16 comporte deux hôtes TPU, chacun avec quatre puces TPU. Dans cet exemple, vous allez démarrer le nœud principal Ray sur un hôte et ajouter le deuxième hôte en tant que nœud de travail au cluster Ray.

  1. Connectez-vous au premier hôte à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
  2. Installez les dépendances avec le même fichier d'exigences que dans la section Exigences d'installation:

    pip install -r requirements.txt
    
  3. Démarrez le processus Ray:

    ray start --head --port=6379
    

    La sortie ressemble à ceci :

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    Cet hôte TPU est désormais le nœud principal de Ray. Notez les lignes qui montrent comment ajouter un autre nœud au cluster Ray, comme suit:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    Vous utiliserez cette commande dans une prochaine étape.

  4. Vérifiez l'état du cluster Ray:

    ray status
    

    La sortie ressemble à ceci :

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Le cluster ne contient que quatre TPU (0.0/4.0 TPU), car vous n'avez ajouté que le nœud racine jusqu'à présent.

Maintenant que le nœud principal est en cours d'exécution, vous pouvez ajouter le deuxième hôte au cluster.

  1. Connectez-vous au deuxième hôte à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
  2. Installez les dépendances avec le même fichier d'exigences que dans la section Exigences d'installation:

    pip install -r requirements.txt
    
  3. Démarrez le processus Ray. Utilisez la commande de la sortie de la commande ray start pour ajouter ce nœud au cluster Ray existant. Veillez à remplacer l'adresse IP et le port dans la commande suivante:

    ray start --address='10.130.0.76:6379'

    La sortie ressemble à ceci :

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  4. Vérifiez à nouveau l'état de Ray:

    ray status
    

    La sortie ressemble à ceci :

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    Le deuxième hôte TPU est désormais un nœud du cluster. La liste des ressources disponibles affiche désormais huit TPU (0.0/8.0 TPU).

Exécuter une charge de travail Ray

  1. Modifiez l'extrait de code pour qu'il s'exécute sur le cluster Ray:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster
    NUM_OF_HOSTS = 2
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] 
    ray.get(tasks)
    
    ray.shutdown()
    
  2. Exécutez le script sur le nœud principal de Ray. Remplacez ray-workload.py par le chemin d'accès à votre script.

    python ray-workload.py

    Le résultat ressemble à ce qui suit :

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    La sortie indique que la fonction a bien été appelée sur chaque appareil XLA (huit appareils dans cet exemple) de la tranche TPU multi-hôte.

Mode hôte centré (JAX)

Les sections suivantes décrivent le mode hôte centré avec JAX. JAX utilise un paradigme de programmation fonctionnelle et prend en charge la sémantique SPMD (single program, multiple data) de niveau supérieur. Au lieu de laisser chaque processus interagir avec un seul appareil XLA, le code JAX est conçu pour fonctionner simultanément sur plusieurs appareils d'un même hôte.

JAX est conçu pour le calcul hautes performances et peut utiliser efficacement les TPU pour l'entraînement et l'inférence à grande échelle. Ce mode est idéal si vous connaissez les concepts de programmation fonctionnelle afin de pouvoir exploiter tout le potentiel de JAX.

Ces instructions supposent que vous avez déjà configuré un environnement Ray et TPU, y compris un environnement logiciel qui inclut JAX et d'autres packages associés. Pour créer un cluster TPU Ray, suivez les instructions de la section Démarrer un clusterGoogle Cloud GKE avec des TPU pour KubeRay. Pour en savoir plus sur l'utilisation des TPU avec KubeRay, consultez Utiliser des TPU avec KubeRay.

Exécuter une charge de travail JAX sur un TPU à hôte unique

L'exemple de script suivant montre comment exécuter une fonction JAX sur un cluster Ray avec un TPU à hôte unique, tel qu'un v6e-4. Si vous disposez d'un TPU multi-hôte, ce script cesse de répondre en raison du modèle d'exécution multicontrôleur de JAX. Pour en savoir plus sur l'exécution de Ray sur un TPU multi-hôte, consultez Exécuter une charge de travail JAX sur un TPU multi-hôte.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Si vous avez l'habitude d'exécuter Ray avec des GPU, vous constaterez quelques différences clés lorsque vous utiliserez des TPU:

  • Au lieu de définir num_gpus, vous spécifiez TPU en tant que ressource personnalisée et définissez le nombre de puces TPU.
  • Vous spécifiez le TPU à l'aide du nombre de puces par nœud de travail Ray. Par exemple, si vous utilisez un v6e-4, l'exécution d'une fonction distante avec TPU défini sur 4 consomme l'hôte TPU entier.
    • Cela diffère de la façon dont les GPU s'exécutent généralement, avec un processus par hôte. Il est déconseillé de définir TPU sur un nombre autre que 4.
    • Exception: Si vous disposez d'un v6e-8 ou d'un v5litepod-8 à hôte unique, vous devez définir cette valeur sur 8.

Exécuter une charge de travail JAX sur un TPU multi-hôte

L'exemple de script suivant montre comment exécuter une fonction JAX sur un cluster Ray avec un TPU multi-hôte. L'exemple de script utilise une version v6e-16.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Si vous avez l'habitude d'exécuter Ray avec des GPU, vous constaterez quelques différences clés lorsque vous utiliserez des TPU:

  • Comme pour les charges de travail PyTorch sur les GPU :
  • Contrairement aux charges de travail PyTorch sur les GPU, JAX dispose d'une vue globale des appareils disponibles dans le cluster.

Exécuter une charge de travail JAX Multislice

Multislice vous permet d'exécuter des charges de travail couvrant plusieurs tranches de TPU dans un seul pod TPU ou dans plusieurs pods sur le réseau du centre de données.

Vous pouvez utiliser le package ray-tpu pour simplifier les interactions de Ray avec les tranches TPU. Installez ray-tpu à l'aide de pip:

pip install ray-tpu

L'exemple de script suivant montre comment utiliser le package ray-tpu pour exécuter des charges de travail multicouches à l'aide d'acteurs ou de tâches Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Orchestrer des charges de travail à l'aide de Ray et de MaxText

Cette section explique comment utiliser Ray pour orchestrer des charges de travail à l'aide de MaxText, une bibliothèque Open Source évolutive et hautes performances pour l'entraînement de LLM à l'aide de JAX et XLA.

MaxText contient un script d'entraînement, train.py, qui doit s'exécuter sur chaque hôte TPU. Cela est semblable aux autres charges de travail de machine learning SPMD. Pour ce faire, utilisez le package ray-tpu et créez un wrapper autour de la fonction principale train.py. Les étapes suivantes montrent comment utiliser le package ray-tpu pour exécuter MaxText sur un TPU v4-16.

  1. Définissez des variables d'environnement pour les paramètres de création de TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=ZONE
    export ACCELERATOR_TYPE=v6e-16
    export VERSION=v2-alpha-tpuv6e
  2. Créez un TPU v6e-16:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION
  3. Clonez le dépôt MaxText sur tous les nœuds TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="git clone https://github.com/AI-Hypercomputer/maxtext"
  4. Installez les exigences MaxText sur tous les nœuds TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install -r maxtext/requirements.txt"
  5. Installez le package ray-tpu sur tous les nœuds TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray-tpu"
  6. Connectez-vous au nœud de calcul 0 à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=0
  7. Enregistrez le script suivant dans un fichier nommé ray_trainer.py dans le répertoire ~/maxtext/MaxText. Ce script utilise le package ray-tpu et crée un wrapper autour de la fonction principale train.py de MaxText.

    import ray
    import ray_tpu
    from train import main as maxtext_main
    
    import logging
    from typing import Sequence
    from absl import app
    
    # Default env vars that run on all TPU VMs.
    MACHINE_ENV_VARS = {
        "ENABLE_PJRT_COMPATIBILITY": "true",
        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
    }
    
    def setup_loggers():
        """Sets up loggers for Ray."""
        logging.basicConfig(level=logging.INFO)
    
    @ray_tpu.remote(
        topology={"v4-16": 1},
    )
    def run_maxtext_train(argv: Sequence[str]):
        maxtext_main(argv=argv)
    
    def main(argv: Sequence[str]):
        ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
    
        logging.info(f"argv: {argv}")
    
        try:
            ray.get(run_maxtext_train(argv=argv))
        except Exception as e:
            logging.error("Caught error during training: %s", e)
            logging.error("Shutting down...")
            ray.shutdown()
            raise e
    
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        app.run(main)
    
  8. Exécutez le script en exécutant la commande suivante:

        python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
            base_output_directory=/tmp/maxtext \
            dataset_type=synthetic \
            per_device_batch_size=2 \
            max_target_length=8192 \
            model_name=default \
            steps=100 \
            run_name=test
    

    La sortie ressemble à ceci :

    (run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=78967, ip=10.130.0.11)
    (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
    (run_maxtext_train pid=80538, ip=10.130.0.12)
    (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
    (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
    (run_maxtext_train pid=80538, ip=10.130.0.12)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    
    ...
    
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
    I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
    

Ressources TPU et Ray

Ray traite les TPU différemment des GPU pour tenir compte de la différence d'utilisation. Dans l'exemple suivant, il y a neuf nœuds Ray au total:

  • Le nœud principal Ray s'exécute sur une VM n1-standard-16.
  • Les nœuds de calcul Ray s'exécutent sur deux TPU v6e-16. Chaque TPU constitue quatre nœuds de calcul.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Descriptions des champs d'utilisation des ressources:

  • CPU: nombre total de processeurs disponibles dans le cluster.
  • TPU: nombre de puces TPU dans le cluster.
  • TPU-v6e-16-head: identifiant spécial de la ressource correspondant au nœud de calcul 0 d'une tranche TPU. Cela est important pour accéder à des tranches de TPU individuelles.
  • memory: mémoire de tas de nœuds de calcul utilisée par votre application.
  • object_store_memory: mémoire utilisée lorsque votre application crée des objets dans le magasin d'objets à l'aide de ray.put et lorsqu'elle renvoie des valeurs à partir de fonctions distantes.
  • tpu-group-0 et tpu-group-1: identifiants uniques des tranches de TPU individuelles. Cela est important pour exécuter des tâches sur des tranches. Ces champs sont définis sur 4, car il y a quatre hôtes par tranche de TPU dans un v6e-16.