Escala las cargas de trabajo de AA con Ray

Introducción

La herramienta Ray de Cloud TPU combina la API de Cloud TPU y Ray Jobs con el objetivo de mejorar la experiencia de desarrollo de los usuarios en Cloud TPU. En esta guía del usuario, se proporciona un ejemplo mínimo de cómo puedes usar Ray con las Cloud TPU. Estos ejemplos no están destinados a usarse en servicios de producción y solo se usan con fines ilustrativos.

¿Qué incluye esta herramienta?

Para tu comodidad, la herramienta proporciona lo siguiente:

  • Abstracciones genéricas que ocultan el código estándar para acciones comunes de TPU
  • Ejemplos que puedes bifurcar en tus propios flujos de trabajo básicos

En particular, haz lo siguiente:

  • tpu_api.py: Wrapper de Python para operaciones básicas de TPU mediante la API de Cloud TPU.
  • tpu_controller.py: Representación de clase de una TPU. En esencia, es un wrapper para tpu_api.py.
  • ray_tpu_controller.py: Es el controlador de TPU con funcionalidad Ray. Esto simplifica el código estándar para Ray cluster y Ray Jobs.
  • run_basic_jax.py: Ejemplo básico que muestra cómo usar RayTpuController para print(jax.device_count())
  • run_hp_search.py: Ejemplo básico que muestra cómo se puede usar Ray Tune con JAX/Flax en MNIST
  • run_pax_autoresume.py: Ejemplo que muestra cómo puedes usar RayTpuController para el entrenamiento tolerante a errores con PAX como carga de trabajo de ejemplo.

Configura el nodo principal del clúster de Ray

Una de las formas básicas en que puedes usar Ray con un pod de TPU es configurarlo como un clúster de Ray. La forma natural de hacerlo es crear una VM de CPU separada como VM de coordinador. En el siguiente gráfico, se muestra un ejemplo de una configuración de un clúster de Ray:

Un ejemplo de configuración de un clúster de Ray

Los siguientes comandos muestran cómo puedes configurar un clúster de Ray con Google Cloud CLI:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Para tu comodidad, también proporcionamos secuencias de comandos básicas a fin de crear una VM de coordinador y, luego, implementar el contenido de esta carpeta en tu VM de coordinación. Para ver el código fuente, consulta create_cpu.sh y deploy.sh.

Estas secuencias de comandos establecen algunos valores predeterminados:

  • create_cpu.sh creará una VM llamada $USER-admin y usará cualquier proyecto y zona que tengan los valores predeterminados de gcloud config. Ejecuta gcloud config list para ver esos valores predeterminados.
  • De forma predeterminada, create_cpu.sh asigna un tamaño de disco de arranque de 200 GB.
  • deploy.sh supone que el nombre de tu VM es $USER-admin. Si cambias ese valor en create_cpu.sh, asegúrate de hacerlo en deploy.sh.

Para usar las secuencias de comandos útiles, sigue estos pasos:

  1. Clona el repositorio de GitHub en tu máquina local y, luego, ingresa la carpeta ray_tpu:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. Si no tienes una cuenta de servicio dedicada para la administración de TPU (muy recomendado), configura una:

    $ ./create_tpu_service_account.sh
    
  3. Crea una VM de coordinación:

    $ ./create_cpu.sh
    

    Esta secuencia de comandos instala dependencias en la VM mediante una secuencia de comandos de inicio y se bloquea de forma automática hasta que se complete la secuencia de comandos de inicio.

  4. Implementa el código local en la VM del coordinador:

    $ ./deploy.sh
    
  5. Conéctate a la VM mediante SSH:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    La redirección de puertos se habilita aquí, ya que Ray iniciará automáticamente un panel en el puerto 8265. Desde la máquina a la que conectas mediante SSH a la VM del coordinador, podrás acceder a este panel en http://127.0.0.1:8265/.

  6. Si omitiste el paso 0, configura tus credenciales de gcloud en la VM de la CPU:

    $ (vm) gcloud auth login --update-adc
    

    En este paso, se establece la información del ID del proyecto y se permite que la API de Cloud TPU se ejecute en la VM de coordinación.

  7. Requisitos de instalación:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. Inicia Ray en la VM de coordinador, y esta se convertirá en el nodo principal del clúster de Ray:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

Ejemplos de uso

Ejemplo básico de JAX

run_basic_jax.py es un ejemplo mínimo que demuestra cómo puedes usar el entorno de ejecución de Ray Jobs y Ray en un clúster de Ray con VM de TPU para ejecutar una carga de trabajo de JAX.

Para los frameworks de AA compatibles con Cloud TPU que usan un modelo de programación con varios controladores, como JAX y PyTorch/XLA PJRT, debes ejecutar al menos un proceso por host. Para obtener más información, consulta Modelo de programación de varios procesos. En la práctica, podría verse de la siguiente manera:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Si tienes más de 16 hosts, como v4-128, tendrás problemas de escalabilidad de SSH y es posible que tu comando deba cambiar a lo siguiente:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Esto puede convertirse en un obstáculo para la velocidad del desarrollador si my_bug_free_python_code contiene errores. Una de las formas en que puedes resolver este problema es mediante un organizador como Kubernetes o Ray. Ray incluye el concepto de un entorno de entorno de ejecución que, cuando se aplica, implementa código y dependencias cuando se ejecuta la aplicación de Ray.

La combinación del entorno de ejecución de Ray con el clúster de Ray y los trabajos de Ray te permite omitir el ciclo de SCP/SSH. Si sigues los ejemplos anteriores, puedes ejecutarlo con el siguiente comando:

$ python3 legacy/run_basic_jax.py

El resultado es similar al siguiente:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Entrenamiento tolerante a errores

En este ejemplo, se muestra cómo puedes usar RayTpuController para implementar el entrenamiento tolerante a errores. Para este ejemplo, entrenamos previamente un LLM simple en PAX en una versión v4-16, pero ten en cuenta que puedes reemplazar esta carga de trabajo PAX por cualquier otra carga de trabajo de larga duración. Para ver el código fuente, consulta run_pax_autoresume.py.

Para ejecutar este ejemplo, haz lo siguiente:

  1. Clona paxml a tu VM de coordinador:

    $ git clone https://github.com/google/paxml.git
    

    A fin de demostrar la facilidad de uso que el entorno de ejecución de Ray proporciona para realizar e implementar cambios de JAX, en este ejemplo se requiere que modifiques PAX.

  2. Agrega una nueva configuración del experimento:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Ejecuta run_pax_autoresume.py:

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. A medida que se ejecuta la carga de trabajo, experimenta con lo que sucede cuando borras la TPU, que se llama $USER-tpu-ray de forma predeterminada:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray detectará que la TPU está inactiva y mostrará el siguiente mensaje:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    Y el trabajo volverá a crear la VM de TPU de forma automática y reiniciará el trabajo de entrenamiento para que pueda reanudar el entrenamiento desde el último punto de control (paso 200 en este ejemplo):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

En este ejemplo, se muestra el uso de Ray Tune desde Ray AIR para ajustar los hiperparámetros de MNIST desde JAX/FLAX. Para ver el código fuente, consulta run_hp_search.py.

Para ejecutar este ejemplo, haz lo siguiente:

  1. Instala los requisitos:

    $ pip3 install -r src/tune/requirements.txt
    
  2. Ejecuta run_hp_search.py:

    $ python3 src/tune/run_hp_search.py
    

    El resultado es similar al siguiente:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

Soluciona problemas

No se puede conectar el nodo principal de Ray

Si ejecutas una carga de trabajo que crea o borra el ciclo de vida de la TPU, a veces esto no desconecta los hosts de TPU del clúster de Ray. Esto puede aparecer como errores de gRPC que indican que el nodo principal de Ray no puede conectarse a un conjunto de direcciones IP.

Como resultado, es posible que debas finalizar tu sesión de Ray (ray stop) y reiniciarla (ray start --head --port=6379 --num-cpus=0).

El trabajo de Ray falla directamente sin ningún resultado de registro

PAX es experimental y este ejemplo puede fallar debido a las dependencias de pip. Si eso sucede, es posible que veas algo como esto:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Para ver la causa raíz del error, puedes ir a http://127.0.0.1:8265/ y ver el panel de los trabajos en ejecución o con errores, lo que proporcionará más información. runtime_env_agent.log muestra toda la información de errores relacionada con la configuración de runtime_env, por ejemplo:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts