Scala i carichi di lavoro ML con Ray

Introduzione

Lo strumento Ray Cloud TPU combina l'API Cloud TPU e Ray Job con l'obiettivo di migliorare l'esperienza di sviluppo degli utenti su Cloud TPU. Questa guida dell'utente fornisce un esempio minimo di come utilizzare Ray con le Cloud TPU. Questi esempi non sono destinati a essere utilizzati nei servizi di produzione e sono solo a scopo illustrativo.

Che cosa è incluso in questo strumento?

Per praticità, lo strumento fornisce:

  • Astrazioni generiche che nascondono il boilerplate per le azioni TPU comuni
  • Esempi di giocattoli che puoi creare con un fork per i tuoi flussi di lavoro di base

In particolare:

  • tpu_api.py: wrapper Python per le operazioni TPU di base che utilizzano l'API Cloud TPU.
  • tpu_controller.py: rappresentazione della classe di una TPU. Si tratta essenzialmente di un wrapper per tpu_api.py.
  • ray_tpu_controller.py: controller TPU con funzionalità Ray. Elimina il boilerplate per i cluster Ray e i Ray Job.
  • run_basic_jax.py: Esempio di base che mostra come utilizzare RayTpuController per print(jax.device_count()).
  • run_hp_search.py: esempio di base che mostra come Ray Tune può essere utilizzato con JAX/Flax su MNIST.
  • run_pax_autoresume.py: Esempio che mostra come utilizzare RayTpuController per l'addestramento a tolleranza di errore utilizzando PAX come carico di lavoro di esempio.

Configurazione del nodo head del cluster Ray

Uno dei modi di base per utilizzare Ray con un pod di TPU è configurare il pod TPU come cluster Ray. Creare una VM CPU separata come VM coordinatore è il modo naturale per farlo. Il seguente grafico mostra un esempio di configurazione di un cluster Ray:

Esempio di configurazione di un cluster Ray

I seguenti comandi mostrano come configurare un cluster Ray utilizzando Google Cloud CLI:

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

Per praticità, forniamo anche script di base per la creazione di una VM coordinatore e il deployment dei contenuti di questa cartella nella VM coordinatore. Per il codice sorgente, consulta create_cpu.sh e deploy.sh.

Questi script impostano alcuni valori predefiniti:

  • create_cpu.sh creerà una VM denominata $USER-admin e utilizzerà qualsiasi progetto e zona impostati sulle impostazioni predefinite di gcloud config. Esegui gcloud config list per visualizzare queste impostazioni predefinite.
  • create_cpu.sh alloca per impostazione predefinita una dimensione del disco di avvio di 200 GB.
  • deploy.sh presuppone che il nome della VM sia $USER-admin. Se modifichi questo valore in create_cpu.sh, assicurati di modificarlo in deploy.sh.

Per utilizzare gli script di convenienza:

  1. Clona il repository GitHub nella tua macchina locale e inserisci la cartella ray_tpu:

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. Se non hai un account di servizio dedicato per l'amministrazione delle TPU (vivamente consigliato), configurane uno:

    $ ./create_tpu_service_account.sh
    
  3. Crea una VM coordinatore:

    $ ./create_cpu.sh
    

    Questo script installa le dipendenze sulla VM utilizzando uno script di avvio e blocca automaticamente fino al completamento dello script di avvio.

  4. Esegui il deployment del codice locale nella VM coordinatore:

    $ ./deploy.sh
    
  5. SSH alla VM:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    Il port forwarding è abilitato qui, in quanto Ray avvierà automaticamente una dashboard sulla porta 8265. Dalla macchina che utilizzi SSH alla VM coordinatore, potrai accedere a questa dashboard all'indirizzo http://127.0.0.1:8265/.

  6. Se hai saltato il passaggio 0, configura le tue credenziali gcloud all'interno della VM CPU:

    $ (vm) gcloud auth login --update-adc
    

    Questo passaggio imposta le informazioni sull'ID progetto e consente all'API Cloud TPU di essere eseguita sulla VM coordinatore.

  7. Requisiti di installazione:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. Avvia Ray sulla VM coordinatore e la VM coordinatore diventa il nodo head del cluster Ray:

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

Esempi di utilizzo

Esempio di JAX di base

run_basic_jax.py è un esempio minimo che mostra come utilizzare l'ambiente di runtime Ray Jobs e Ray su un cluster Ray con VM TPU per eseguire un carico di lavoro JAX.

Per i framework ML compatibili con le Cloud TPU che utilizzano un modello di programmazione multi-controller, come JAX e PyTorch/XLA PJRT, devi eseguire almeno un processo per host. Per ulteriori informazioni, consulta Modello di programmazione multi-processo. In pratica, potrebbe avere il seguente aspetto:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

Se hai più di 16 host, ad esempio v4-128, riscontrerai problemi di scalabilità SSH e il tuo comando potrebbe dover passare a:

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

Questo può diventare un ostacolo sulla velocità di sviluppo se my_bug_free_python_code contiene bug. Uno dei modi per risolvere questo problema è utilizzare un orchestratore come Kubernetes o Ray. Ray include il concetto di ambiente di runtime che, una volta applicato, esegue il deployment di codice e dipendenze quando viene eseguita l'applicazione Ray.

La combinazione dell'ambiente di runtime Ray con il cluster Ray e i job Ray ti consente di bypassare il ciclo SCP/SSH. Supponendo che tu abbia seguito gli esempi precedenti, puoi eseguire questa operazione con:

$ python3 legacy/run_basic_jax.py

L'output è simile al seguente:

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

Addestramento sulla tolleranza agli errori

Questo esempio mostra come utilizzare RayTpuController per implementare l'addestramento a tolleranza di errore. Per questo esempio, abbiamo preaddestrato un semplice LLM su PAX su una versione 4-16, ma tieni presente che puoi sostituire questo carico di lavoro PAX con qualsiasi altro carico di lavoro a lunga esecuzione. Per il codice sorgente, consulta run_pax_autoresume.py.

Per eseguire questo esempio:

  1. Clona paxml nella VM coordinatore:

    $ git clone https://github.com/google/paxml.git
    

    Per dimostrare la facilità d'uso offerta da Ray Runtime Environment per creare e implementare le modifiche JAX, questo esempio richiede la modifica di PAX.

  2. Aggiungi una nuova configurazione per l'esperimento:

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. Esegui run_pax_autoresume.py:

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. Durante l'esecuzione del carico di lavoro, scopri cosa succede quando elimini la TPU, denominata $USER-tpu-ray per impostazione predefinita:

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray rileverà che la TPU non è disponibile con il seguente messaggio:

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    Il job ricrea automaticamente la VM TPU e riavvia il job di addestramento in modo che possa riprendere l'addestramento dal checkpoint più recente (200 passaggi in questo esempio):

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

Questo esempio mostra l'utilizzo di Ray Tune, di Ray AIR, della sintonia degli iperparametri MNIST di JAX/FLAX. Per il codice sorgente, consulta run_hp_search.py.

Per eseguire questo esempio:

  1. Installa i requisiti:

    $ pip3 install -r src/tune/requirements.txt
    
  2. Esegui run_hp_search.py:

    $ python3 src/tune/run_hp_search.py
    

    L'output è simile al seguente:

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

Risoluzione dei problemi

Impossibile connettere il nodo Ray head

Se esegui un carico di lavoro che crea/elimina il ciclo di vita delle TPU, a volte questo non disconnette gli host TPU dal cluster Ray. Potrebbe essere visualizzato come errori gRPC, che indicano che il nodo Ray head non è in grado di connettersi a un insieme di indirizzi IP.

Di conseguenza, potrebbe essere necessario terminare la sessione ray (ray stop) e riavviarla (ray start --head --port=6379 --num-cpus=0).

Il job Ray non riesce direttamente senza alcun output di log

PAX è sperimentale e questo esempio potrebbe non funzionare a causa di dipendenze da pip. In tal caso, potresti visualizzare qualcosa di simile al seguente:

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

Per conoscere la causa principale dell'errore, puoi andare all'indirizzo http://127.0.0.1:8265/ e visualizzare la dashboard per i job in esecuzione/non riusciti, che fornirà ulteriori informazioni. runtime_env_agent.log mostra tutte le informazioni di errore relative alla configurazione di runtime_env, ad esempio:

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts