Escalonar cargas de trabalho de ML usando o Ray

Introdução

A ferramenta Cloud TPU Ray combina a API Cloud TPU e Jobs do Ray com o objetivo de melhorar experiência de desenvolvimento no Cloud TPU. Isso guia do usuário fornece um exemplo mínimo de como você pode usar o Ray com e Cloud TPUs. Estes exemplos não devem ser usados na produção e têm fins exclusivamente ilustrativos.

O que está incluído nessa ferramenta?

Para sua conveniência, a ferramenta fornece:

  • Abstrações genéricas que ocultam o código boilerplate para ações comuns da TPU
  • Exemplos práticos que você pode bifurcar para seus fluxos de trabalho básicos

Especificamente:

  • tpu_api.py: Wrapper do Python para operações básicas de TPU usando a API Cloud TPU.
  • tpu_controller.py: Representação de classe de uma TPU. Essencialmente, esse é um wrapper para tpu_api.py.
  • ray_tpu_controller.py: Controlador de TPU com funcionalidade Ray. Isso abstrai o modelo para clusters e jobs do Ray.
  • run_basic_jax.py: Exemplo básico que mostra como usar RayTpuController para print(jax.device_count()).
  • run_hp_search.py: Exemplo básico que mostra como o Ray Tune pode ser usado com JAX/Flax no MNIST.
  • run_pax_autoresume.py: Exemplo que mostra como usar RayTpuController para falhas usando PAX como carga de trabalho de exemplo.

Como configurar o nó principal do cluster do Ray

Uma das maneiras básicas de usar o Ray com um pod de TPU é configurar o pod de TPU como um cluster do Ray. Criar uma VM de CPU separada como VM de coordenador é a maneira natural de fazer isso. O gráfico a seguir mostra um exemplo de configuração de cluster do Ray:

Exemplo de configuração de cluster do Ray

Os comandos a seguir mostram como configurar um cluster do Ray usando a Google Cloud CLI:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Para sua conveniência, também disponibilizamos scripts básicos para a criação de um coordenador VM e implantar o conteúdo dessa pasta na VM do coordenador. Para obter o código-fonte, consulte create_cpu.sh e deploy.sh

Esses scripts definem alguns valores padrão:

  • create_cpu.sh criará uma VM chamada $USER-admin e use qualquer projeto e zona definidos com os padrões de gcloud config. Execute gcloud config list para conferir esses padrões.
  • create_cpu.sh por padrão aloca um tamanho de disco de inicialização de 200 GB.
  • deploy.sh pressupõe que o nome da VM é $USER-admin. Se você mude esse valor em create_cpu.sh, mude-o em deploy.sh.

Para usar os scripts de conveniência:

  1. Clone o repositório do GitHub na sua máquina local e acesse a pasta ray_tpu:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
  2. Se você não tiver uma conta de serviço dedicada para a administração de TPU, configure uma:

    $ ./create_tpu_service_account.sh
  3. Crie uma VM de coordenador:

    $ ./create_cpu.sh

    Esse script instala dependências na VM usando um script de inicialização e bloqueia automaticamente até que o script de inicialização seja concluído.

  4. Implante o código local na VM do coordenador:

    $ ./deploy.sh
  5. SSH para a VM:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265

    O encaminhamento de portas está ativado aqui porque o Ray iniciará automaticamente um painel em na porta 8265. Acesse a máquina que você conectou por SSH à VM de coordenação. acessar o painel em http://127.0.0.1:8265/.

  6. Se você tiver pulado a etapa 0, configure suas credenciais do gcloud na VM de CPU:

    $ (vm) gcloud auth login --update-adc

    Esta etapa define as informações do ID do projeto e permite que a API Cloud TPU seja executada na VM do coordenador.

  7. Requisitos de instalação:

    $ (vm) pip3 install -r src/requirements.txt
  8. Inicie o Ray na VM de coordenador, e a VM de coordenador se tornará a nó principal do cluster do Ray:

    $ (vm) ray start --head --port=6379 --num-cpus=0

Exemplos de uso

Exemplo básico de JAX

run_basic_jax.py é um exemplo mínimo que demonstra como usar os jobs do Ray e o ambiente de execução do Ray em um cluster do Ray com VMs TPU para executar uma carga de trabalho do JAX.

Para frameworks de ML compatíveis com Cloud TPUs que usam uma de programação com vários controles, como JAX e PyTorch/XLA PJRT, é preciso executar pelo menos um processo por host. Para mais informações, consulte Modelo de programação de vários processos. Na prática, o resultado será semelhante a este:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Se você tiver mais de 16 hosts, como um v4-128, vai encontrar problemas de escalonamento do SSH, e seu comando poderá mudar para:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Isso pode atrapalhar a velocidade do desenvolvedor se my_bug_free_python_code contém bugs. Uma das maneiras de resolver esse problema é usando um orquestrador, como o Kubernetes ou o Ray. O Ray inclui o conceito de um ambiente de execução que, quando aplicado, implanta código e dependências quando o aplicativo Ray é executado.

A combinação do ambiente de execução do Ray com o cluster e os jobs do Ray permite que você ignore o ciclo SCP/SSH. Supondo que você tenha seguido os exemplos acima, isso pode ser feito com:

$ python3 legacy/run_basic_jax.py

O resultado será assim:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...

I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Treinamento tolerante a falhas

Este exemplo mostra como usar RayTpuController para implementar o treinamento tolerante a falhas. Neste exemplo, pré-treinamos um LLM simples PAX (em inglês) em um sistema operacional v4-16. No entanto, é possível substituir essa carga de trabalho de PAX por qualquer outra carga carga de trabalho em execução. Para obter o código-fonte, consulte run_pax_autoresume.py

Para executar esse exemplo:

  1. Clone paxml para a VM do coordenador:

    $ git clone https://github.com/google/paxml.git

    Para demonstrar a facilidade de uso que o ambiente de execução do Ray oferece para fazer e implantar mudanças no JAX, este exemplo exige que você modifique o PAX.

  2. Adicione uma nova configuração de experimento:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Execute run_pax_autoresume.py:

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
  4. Enquanto a carga de trabalho é executada, teste o que acontece ao excluir a TPU. por padrão, chamado $USER-tpu-ray:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b

    O Ray detectará que a TPU está inativa e exibirá a seguinte mensagem:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.

    O job recria automaticamente a VM TPU e reinicia o job de treinamento para retomar o treinamento do último ponto de verificação (etapa 200 neste exemplo):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally

Este exemplo mostra como usar o Ray Tune do Ray AIR para ajustar os hiperparâmetros MNIST da JAX/FLAX. Para obter o código-fonte, consulte run_hp_search.py

Para executar esse exemplo:

  1. Instale os requisitos:

    $ pip3 install -r src/tune/requirements.txt
  2. Execute run_hp_search.py:

    $ python3 src/tune/run_hp_search.py

    O resultado será assim:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...

Solução de problemas

O nó principal do Ray não consegue se conectar

Se você executar uma carga de trabalho que cria/exclui o ciclo de vida da TPU, às vezes isso não desconecta os hosts da TPU do cluster do Ray. Isso pode aparecer como erros gRPC que indicam que o nó principal do Ray não consegue se conectar a um conjunto de endereços IP.

Por isso, talvez você precise encerrar sua sessão de raio (ray stop) e reiniciar o dispositivo. (ray start --head --port=6379 --num-cpus=0).

O job do Ray falha diretamente sem nenhuma saída de registro

O PAX é experimental, e este exemplo pode falhar devido a dependências do pip. Se isso acontecer, você verá algo assim:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Para conferir a causa raiz do erro, acesse http://127.0.0.1:8265/ e abra o painel dos jobs em execução/falha, que vai fornecer mais informações. runtime_env_agent.log mostra todas as informações de erro relacionadas a configuração de runtime_env, por exemplo:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts