Escalonar cargas de trabalho de ML usando o Ray

Este documento fornece detalhes sobre como executar cargas de trabalho de machine learning (ML) com o Ray e o JAX em TPUs. Há dois modos diferentes para usar TPUs com o Ray: modo centrado no dispositivo (PyTorch/XLA) e modo centrado no host (JAX).

Este documento pressupõe que você já tenha configurado um ambiente de TPU. Para mais informações, consulte os seguintes recursos:

Modo centrado no dispositivo (PyTorch/XLA)

O modo centrado no dispositivo mantém grande parte do estilo programático do PyTorch clássico. Nesse modo, você adiciona um novo tipo de dispositivo XLA, que funciona como qualquer outro dispositivo PyTorch. Cada processo individual interage com um dispositivo XLA.

Esse modo é ideal se você já conhece o PyTorch com GPUs e quer usar abstrações de programação semelhantes.

As seções a seguir descrevem como executar uma carga de trabalho do PyTorch/XLA em um ou mais dispositivos sem usar o Ray e, em seguida, como executar a mesma carga de trabalho em vários hosts usando o Ray.

Criar TPU

  1. Crie variáveis de ambiente para os parâmetros de criação de TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE=v5p-8
    export VERSION=v2-alpha-tpuv5

    Descrições das variáveis de ambiente

    TPU_NAME
    O nome do novo Cloud TPU.
    ZONE
    A zona em que o Cloud TPU será criado.
    accelerator-type
    O tipo de acelerador especifica a versão e o tamanho da Cloud TPU que você quer criar. Para mais informações, consulte Versões da TPU.
    version
    A versão do software de TPU que você quer usar. Para mais informações, consulte Imagens de VM de TPU.
  2. Use o comando a seguir para criar uma VM TPU v5p com 8 núcleos:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE  \
        --version=$VERSION
  3. Conecte-se à VM da TPU usando o seguinte comando:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

Se você estiver usando o GKE, consulte o guia do KubeRay no GKE para informações de configuração.

Requisitos de instalação

Execute os seguintes comandos na VM da TPU para instalar as dependências necessárias:

  1. Salve o seguinte em um arquivo, por exemplo, requirements.txt:

    --find-links https://storage.googleapis.com/libtpu-releases/index.html
    --find-links https://storage.googleapis.com/libtpu-wheels/index.html
    torch~=2.6.0
    torch_xla[tpu]~=2.6.0
    ray[default]==2.40.0
    
  2. Execute o comando a seguir para instalar as dependências necessárias:

    pip install -r requirements.txt
    

Se você estiver executando sua carga de trabalho no GKE, recomendamos criar um Dockerfile que instale as dependências necessárias. Para conferir um exemplo, consulte Executar a carga de trabalho em nós de fatia de TPU na documentação do GKE.

Executar uma carga de trabalho do PyTorch/XLA em um único dispositivo

O exemplo a seguir demonstra como criar um tensor XLA em um único dispositivo, que é um chip TPU. Isso é semelhante à forma como o PyTorch lida com outros tipos de dispositivos.

  1. Salve o snippet de código abaixo em um arquivo, por exemplo, workload.py:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)
    

    A instrução de importação import torch_xla inicializa o PyTorch/XLA, e a função xm.xla_device() retorna o dispositivo XLA atual, um chip TPU.

  2. Defina a variável de ambiente PJRT_DEVICE como TPU:

    export PJRT_DEVICE=TPU
    
  3. Execute o script:

    python workload.py
    

    A saída será semelhante à mostrada abaixo. Verifique se a saída indica que o dispositivo XLA foi encontrado.

    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

Executar o PyTorch/XLA em vários dispositivos

  1. Atualize o snippet de código da seção anterior para ser executado em vários dispositivos:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    
    def _mp_fn(index):
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    if __name__ == '__main__':
        torch_xla.launch(_mp_fn, args=())
    
  2. Execute o script:

    python workload.py
    

    Se você executar o snippet de código em um TPU v5p-8, a saída será semelhante a esta:

    xla:0
    xla:0
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    xla:0
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    tensor([[ 1.2309,  0.9896],
            [ 0.5820, -1.2950]], device='xla:0')
    

torch_xla.launch() usa dois argumentos, uma função e uma lista de parâmetros. Ele cria um processo para cada dispositivo XLA disponível e chama a função especificada nos argumentos. Neste exemplo, há quatro dispositivos TPU disponíveis, então torch_xla.launch() cria quatro processos e chama _mp_fn() em cada dispositivo. Cada processo tem acesso apenas a um dispositivo, portanto, cada dispositivo tem o índice 0 e xla:0 é impresso para todos os processos.

Executar o PyTorch/XLA em vários hosts com o Ray

As seções a seguir mostram como executar o mesmo snippet de código em uma fatia de TPU multihost maior. Para mais informações sobre a arquitetura de TPU com vários hosts, consulte Arquitetura do sistema.

Neste exemplo, você vai configurar o Ray manualmente. Se você já sabe como configurar o Ray, pule para a última seção, Executar uma carga de trabalho do Ray. Para mais informações sobre como configurar o Ray para um ambiente de produção, consulte os seguintes recursos:

Criar uma VM de TPU com vários hosts

  1. Crie variáveis de ambiente para os parâmetros de criação de TPU:

    export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
    export ZONE=europe-west4-b
    export ACCELERATOR_TYPE_MULTIHOST=v5p-16
    export VERSION=v2-alpha-tpuv5
  2. Crie uma TPU v5p de vários hosts com dois hosts (um v5p-16, com quatro chips de TPU em cada host) usando o seguinte comando:

    gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
        --version=$VERSION

Configurar o Ray

Uma TPU v5p-16 tem dois hosts de TPU, cada um com quatro chips de TPU. Neste exemplo, você vai iniciar o nó principal do Ray em um host e adicionar o segundo host como um nó de trabalho ao cluster do Ray.

  1. Conecte-se ao primeiro host usando SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
  2. Instale as dependências com o mesmo arquivo de requisitos da seção "Instalar requisitos":

    pip install -r requirements.txt
    
  3. Inicie o processo do Ray:

    ray start --head --port=6379
    

    A saída será assim:

    Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
    Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
    
    Local node IP: 10.130.0.76
    
    --------------------
    Ray runtime started.
    --------------------
    
    Next steps
    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    
    To connect to this Ray cluster:
        import ray
        ray.init()
    
    To terminate the Ray runtime, run
        ray stop
    
    To view the status of the cluster, use
        ray status
    

    Esse host da TPU agora é o nó principal do Ray. Anote as linhas que mostram como adicionar outro nó ao cluster do Ray, semelhante a este:

    To add another node to this Ray cluster, run
        ray start --address='10.130.0.76:6379'
    

    Você vai usar esse comando em uma etapa posterior.

  4. Verifique o status do cluster do Ray:

    ray status
    

    A saída será assim:

    ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/208.0 CPU
    0.0/4.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/268.44GiB memory
    0B/119.04GiB object_store_memory
    0.0/1.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    O cluster contém apenas quatro TPUs (0.0/4.0 TPU) porque você só adicionou o nó principal até agora.

Agora que o nó principal está em execução, você pode adicionar o segundo host ao cluster.

  1. Conecte-se ao segundo host usando SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
  2. Instale as dependências com o mesmo arquivo de requisitos da seção "Instalar requisitos":

    pip install -r requirements.txt
    
  3. Inicie o processo do Ray. Use o comando da saída do comando ray start para adicionar esse nó ao cluster do Ray. Substitua o endereço IP e a porta no comando a seguir:

    ray start --address='10.130.0.76:6379'

    A saída será assim:

    Local node IP: 10.130.0.80
    [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    
    --------------------
    Ray runtime started.
    --------------------
    
    To terminate the Ray runtime, run
    ray stop
    
  4. Verifique o status do Ray novamente:

    ray status
    

    A saída será assim:

    ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
    Node status
    ---------------------------------------------------------------
    Active:
    1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
    1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
    Pending:
    (no pending nodes)
    Recent failures:
    (no failures)
    
    Resources
    ---------------------------------------------------------------
    Usage:
    0.0/416.0 CPU
    0.0/8.0 TPU
    0.0/1.0 TPU-v5p-16-head
    0B/546.83GiB memory
    0B/238.35GiB object_store_memory
    0.0/2.0 your-tpu-name
    
    Demands:
    (no resource demands)
    

    O segundo host de TPU agora é um nó no cluster. A lista de recursos disponíveis agora mostra 8 TPUs (0.0/8.0 TPU).

Executar uma carga de trabalho do Ray

  1. Atualize o snippet de código para ser executado no cluster do Ray:

    import os
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import ray
    
    import torch.distributed as dist
    import torch_xla.runtime as xr
    from torch_xla._internal import pjrt
    
    # Defines the local PJRT world size, the number of processes per host
    LOCAL_WORLD_SIZE = 4
    # Defines the number of hosts in the Ray cluster
    NUM_OF_HOSTS = 2
    GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS
    
    def init_env():
        local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])
    
        pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
        xr._init_world_size_ordinal()
    
    # This decorator signals to Ray that the print_tensor() function should be run on a single TPU chip
    @ray.remote(resources={"TPU": 1})
    def print_tensor():
        # Initializes the runtime environment on each Ray worker. Equivalent to
        # the `torch_xla.launch call` in the Run PyTorch/XLA on multiple devices section.
        init_env()
    
        t = torch.randn(2, 2, device=xm.xla_device())
        print(t.device)
        print(t)
    
    ray.init()
    
    # Uses Ray to dispatch the function call across available nodes in the cluster
    tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)] 
    ray.get(tasks)
    
    ray.shutdown()
    
  2. Execute o script no nó principal do Ray. Substitua ray-workload.py pelo caminho do script.

    python ray-workload.py

    O resultado será assim:

    WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
    xla:0
    xla:0
    xla:0
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    xla:0
    tensor([[ 0.6220, -1.4707],
            [-1.2112,  0.7024]], device='xla:0')
    

    A saída indica que a função foi chamada em cada dispositivo XLA (8 dispositivos neste exemplo) na fração de TPU de vários hosts.

Modo centrado no host (JAX)

As seções a seguir descrevem o modo centrado no host com o JAX. O JAX utiliza um paradigma de programação funcional e oferece suporte a semântica de programa único de nível superior, dados múltiplos (SPMD, na sigla em inglês). Em vez de cada processo interagir com um único dispositivo XLA, o código JAX foi projetado para operar em vários dispositivos em um único host simultaneamente.

O JAX foi projetado para computação de alto desempenho e pode usar TPUs de maneira eficiente para treinamento e inferência em grande escala. Esse modo é ideal se você já conhece os conceitos de programação funcional para aproveitar todo o potencial do JAX.

Estas instruções pressupõem que você já tenha configurado um ambiente Ray e TPU, incluindo um ambiente de software que inclui JAX e outros pacotes relacionados. Para criar um cluster do Ray TPU, siga as instruções em Iniciar Google Cloud cluster do GKE com TPUs para KubeRay. Para mais informações sobre o uso de TPUs com o KubeRay, consulte Usar TPUs com o KubeRay.

Executar uma carga de trabalho do JAX em uma TPU de host único

O exemplo de script abaixo demonstra como executar uma função JAX em um cluster do Ray com um TPU de host único, como um v6e-4. Se você tiver uma TPU de vários hosts, esse script vai parar de responder devido ao modelo de execução de vários controladores do JAX. Para mais informações sobre como executar o Ray em uma TPU de vários hosts, consulte Executar uma carga de trabalho do JAX em uma TPU de vários hosts.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h)) # => 4

Se você está acostumado a executar o Ray com GPUs, há algumas diferenças importantes ao usar TPUs:

  • Em vez de definir num_gpus, especifique TPU como um recurso personalizado e defina o número de chips de TPU.
  • Especifique a TPU usando o número de chips por nó de worker do Ray. Por exemplo, se você estiver usando uma v6e-4, a execução de uma função remota com TPU definido como 4 consome todo o host TPU.
    • Isso é diferente da forma como as GPUs normalmente são executadas, com um processo por host. Não é recomendável definir TPU como um número diferente de 4.
    • Exceção: se você tiver um v6e-8 ou v5litepod-8 de host único, defina esse valor como 8.

Executar uma carga de trabalho do JAX em uma TPU de vários hosts

O exemplo de script abaixo demonstra como executar uma função JAX em um cluster do Ray com uma TPU de vários hosts. O script de exemplo usa uma v6e-16.

import ray
import jax

@ray.remote(resources={"TPU": 4})
def my_function() -> int:
    return jax.device_count()

num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]

Se você está acostumado a executar o Ray com GPUs, há algumas diferenças importantes ao usar TPUs:

Executar uma carga de trabalho JAX Multislice

O Multislice permite executar cargas de trabalho que abrangem várias frações de TPU em um único pod de TPU ou em vários pods na rede do data center.

Você pode usar o pacote ray-tpu para simplificar as interações do Ray com fatias de TPU. Instale ray-tpu usando pip:

pip install ray-tpu

O exemplo de script a seguir mostra como usar o pacote ray-tpu para executar cargas de trabalho de várias fatias usando atores ou tarefas do Ray:

from ray_tpu import RayTpuManager
import jax
import ray

ray.init()

# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
    def get_devices(self):
        return jax.device_count()

# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
    return jax.device_count()

tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus) 
"""
TPU resources:
{'v6e-16': [
    RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
    RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""

# if using actors
actors = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=MyActor,
    multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]

# if using tasks
h = RayTpuManager.remote(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]

# note - you can also run this without Multislice
h = RayTpuManager.run_task(
    tpus=tpus["v6e-16"],
    actor_or_fn=get_devices,
    multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]

Orquestrar cargas de trabalho usando o Ray e o MaxText

Esta seção descreve como usar o Ray para orquestrar cargas de trabalho usando MaxText, uma biblioteca de código aberto escalonável e de alto desempenho para treinamento de LLMs usando JAX e XLA.

O MaxText contém um script de treinamento, train.py, que precisa ser executado em cada host de TPU. Isso é semelhante a outras cargas de trabalho de machine learning SPMD. É possível fazer isso usando o pacote ray-tpu e criando um wrapper em torno da função principal train.py. As etapas a seguir mostram como usar o pacote ray-tpu para executar o MaxText em uma TPU v4-16.

  1. Defina variáveis de ambiente para os parâmetros de criação de TPU:

    export TPU_NAME=TPU_NAME
    export ZONE=ZONE
    export ACCELERATOR_TYPE=v6e-16
    export VERSION=v2-alpha-tpuv6e
  2. Crie uma TPU v6e-16:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$VERSION
  3. Clone o repositório MaxText em todos os workers da TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="git clone https://github.com/AI-Hypercomputer/maxtext"
  4. Instale os requisitos do MaxText em todos os workers do TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install -r maxtext/requirements.txt"
  5. Instale o pacote ray-tpu em todos os workers de TPU:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray-tpu"
  6. Conecte-se ao worker 0 usando SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=0
  7. Salve o script abaixo em um arquivo chamado ray_trainer.py no diretório ~/maxtext/MaxText. Esse script usa o pacote ray-tpu e cria um wrapper em torno da função principal train.py do MaxText.

    import ray
    import ray_tpu
    from train import main as maxtext_main
    
    import logging
    from typing import Sequence
    from absl import app
    
    # Default env vars that run on all TPU VMs.
    MACHINE_ENV_VARS = {
        "ENABLE_PJRT_COMPATIBILITY": "true",
        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
    }
    
    def setup_loggers():
        """Sets up loggers for Ray."""
        logging.basicConfig(level=logging.INFO)
    
    @ray_tpu.remote(
        topology={"v4-16": 1},
    )
    def run_maxtext_train(argv: Sequence[str]):
        maxtext_main(argv=argv)
    
    def main(argv: Sequence[str]):
        ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))
    
        logging.info(f"argv: {argv}")
    
        try:
            ray.get(run_maxtext_train(argv=argv))
        except Exception as e:
            logging.error("Caught error during training: %s", e)
            logging.error("Shutting down...")
            ray.shutdown()
            raise e
    
        logging.info("Training complete!")
        ray.shutdown()
    
    if __name__ == "__main__":
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        app.run(main)
    
  8. Para executar o script, execute o seguinte comando:

        python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
            base_output_directory=/tmp/maxtext \
            dataset_type=synthetic \
            per_device_batch_size=2 \
            max_target_length=8192 \
            model_name=default \
            steps=100 \
            run_name=test
    

    A saída será assim:

    (run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=78967, ip=10.130.0.11)
    (run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
    (run_maxtext_train pid=80538, ip=10.130.0.12)
    (run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
    (run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
    (run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
    (run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
    (run_maxtext_train pid=80538, ip=10.130.0.12)   Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    
    ...
    
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
    (run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
    I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
    (run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
    

Recursos de TPU e Ray

O Ray trata as TPUs de maneira diferente das GPUs para acomodar a diferença no uso. No exemplo a seguir, há nove nós Ray no total:

  • O nó principal do Ray está em execução em uma VM n1-standard-16.
  • Os nós de trabalho do Ray estão sendo executados em duas TPUs v6e-16. Cada TPU constitui quatro workers.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
 1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
 1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
 1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
 1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
 1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
 1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
 1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
 1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
 1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/727.0 CPU
 0.0/32.0 TPU
 0.0/2.0 TPU-v6e-16-head
 0B/5.13TiB memory
 0B/1.47TiB object_store_memory
 0.0/4.0 tpu-group-0
 0.0/4.0 tpu-group-1

Demands:
 (no resource demands)

Descrições dos campos de uso de recursos:

  • CPU: o número total de CPUs disponíveis no cluster.
  • TPU: o número de chips de TPU no cluster.
  • TPU-v6e-16-head: um identificador especial para o recurso que corresponde ao worker 0 de uma fração de TPU. Isso é importante para acessar fatias de TPU individuais.
  • memory: memória heap do worker usada pelo aplicativo.
  • object_store_memory: memória usada quando o aplicativo cria objetos no armazenamento de objetos usando ray.put e quando ele retorna valores de funções remotas.
  • tpu-group-0 e tpu-group-1: identificadores exclusivos para os segmentos de TPU individuais. Isso é importante para executar trabalhos em fatias. Esses campos são definidos como 4 porque há 4 hosts por fatia de TPU em uma v6e-16.