ML-Arbeitslasten mit Ray skalieren
Einleitung
Das Cloud TPU Ray-Tool kombiniert die Cloud TPU API und Ray-Jobs um die Fähigkeiten der Nutzenden Entwicklungsumgebung auf Cloud TPU. Dieses enthält ein kurzes Beispiel dafür, wie Sie Ray Cloud TPUs Diese Beispiele sind nicht für die Verwendung in Produktionsdiensten gedacht und dienen nur der Veranschaulichung.
Was beinhaltet dieses Tool?
Das Tool bietet Ihnen folgende Möglichkeiten:
- Allgemeine Abstraktionsschichten, die Boilerplate-Code für gängige TPU-Aktionen ausblenden
- Spielzeugbeispiele, die Sie für Ihre eigenen grundlegenden Workflows abspalten können
Zum Beispiel:
tpu_api.py
: Python-Wrapper für einfache TPU-Vorgänge mit der Cloud TPU APItpu_controller.py
: Klassenrepräsentation einer TPU. Dies ist im Wesentlichen ein Wrapper fürtpu_api.py
.ray_tpu_controller.py
: TPU-Controller mit Ray-Funktionalität Dadurch wird die Boilerplate für Ray-Cluster und Ray-Jobs abstrahiert.run_basic_jax.py
: rudimentäres Beispiel, das zeigt, wieRayTpuController
fürprint(jax.device_count())
verwendet wirdrun_hp_search.py
: Einfaches Beispiel, das zeigt, wie Ray Tune mit JAX/Flax auf MNIST verwendet werden kann.run_pax_autoresume.py
: Beispiel für die Verwendung vonRayTpuController
für Fehler tolerantes Training mit PAX als Beispielarbeitslast.
Ray-Cluster-Head-Knoten einrichten
Eine der grundlegenden Möglichkeiten zur Verwendung von Ray mit einem TPU-Pod besteht darin, den TPU-Pod einzurichten. als Ray-Cluster. Das Erstellen einer separaten CPU-VM als Koordinator-VM wie das geht. Die folgende Grafik zeigt ein Beispiel für eine Ray-Clusterkonfiguration:
Die folgenden Befehle zeigen, wie Sie einen Ray-Cluster mit der Google Cloud CLI einrichten können:
$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ... $ gcloud compute ssh my_tpu_admin $ (vm) pip3 install ray[default] $ (vm) ray start --head --port=6379 --num-cpus=0 ... # (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP) $ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"
Zur Vereinfachung stellen wir auch einfache Scripts zum Erstellen einer Koordinator-VM und zum Bereitstellen des Inhalts dieses Ordners auf der Koordinator-VM bereit. Quellcode: create_cpu.sh
und deploy.sh
In diesen Scripts werden einige Standardwerte festgelegt:
create_cpu.sh
erstellt eine VM mit dem Namen$USER-admin
und wird das Projekt und die Zone zu nutzen, die auf Ihregcloud config
-Standardeinstellungen eingestellt sind. Führen Siegcloud config list
aus, um diese Standardwerte zu sehen.create_cpu.sh
weist standardmäßig ein Bootlaufwerk von 200 GB zu.deploy.sh
geht davon aus, dass Ihr VM-Name$USER-admin
ist. Wenn Sie ändere diesen Wert increate_cpu.sh
und ändere ihn indeploy.sh
.
So verwenden Sie die Convenience-Skripts:
Klonen Sie das GitHub-Repository auf Ihren lokalen Computer und geben Sie den Ordner „
ray_tpu
“:$ git clone https://github.com/tensorflow/tpu.git $ cd tpu/tools/ray_tpu/
Wenn Sie kein spezielles Dienstkonto für die TPU-Verwaltung haben (empfohlen), richten Sie eines ein:
$ ./create_tpu_service_account.sh
Koordinator-VM erstellen:
$ ./create_cpu.sh
Dieses Script installiert Abhängigkeiten auf der VM mithilfe eines Startscripts und blockiert automatisch, bis das Startscript abgeschlossen ist.
Stellen Sie lokalen Code auf der Koordinator-VM bereit:
$ ./deploy.sh
Stellen Sie eine SSH-Verbindung zur VM her:
$ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
Die Portweiterleitung ist hier aktiviert, da Ray automatisch ein Dashboard unter Port 8265 startet. Über den Computer, mit dem Sie eine SSH-Verbindung zur Koordinator-VM herstellen, können Sie unter http://127.0.0.1:8265/ auf dieses Dashboard zugreifen.
Wenn Sie Schritt 0 übersprungen haben, richten Sie Ihre gcloud-Anmeldedaten in der CPU-VM ein:
$ (vm) gcloud auth login --update-adc
In diesem Schritt werden Informationen zur Projekt-ID festgelegt und die Cloud TPU API kann auf der VM des Koordinators ausgeführt werden.
Installationsvoraussetzungen:
$ (vm) pip3 install -r src/requirements.txt
Wenn Sie Ray auf der Koordinator-VM starten, wird die Koordinator-VM zur Hauptknoten des Ray-Clusters:
$ (vm) ray start --head --port=6379 --num-cpus=0
Beispiele für die Verwendung
Einfaches JAX-Beispiel
run_basic_jax.py
ist ein minimales Beispiel, das zeigt, wie Sie Ray Jobs und Ray
Laufzeitumgebung in einem Ray-Cluster mit TPU-VMs zum Ausführen einer JAX-Arbeitslast
Bei mit Cloud TPUs kompatiblen ML-Frameworks, die ein Programmiermodell mit mehreren Controllern verwenden, z. B. JAX und PyTorch/XLA PJRT, müssen Sie mindestens einen Prozess pro Host ausführen. Weitere Informationen finden Sie unter Mehrfachprozess-Programmiermodell. In der Praxis könnte das so aussehen:
$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all $ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"
Wenn Sie mehr als etwa 16 Hosts haben, z. B. eine v4-128, treten Probleme mit der SSH-Skalierung auf und der Befehl muss möglicherweise in Folgendes geändert werden:
$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8 $ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8
Dies kann die Entwicklungsgeschwindigkeit beeinträchtigen, wenn my_bug_free_python_code
enthält Insekten. Eine Möglichkeit, dieses Problem zu lösen, besteht darin, einen Orchestrator wie Kubernetes oder Ray zu verwenden. Ray beinhaltet das Konzept eines
Laufzeitumgebung
die bei Anwendung Code und Abhängigkeiten bereitstellt, wenn die Ray-Anwendung
ausführen.
Wenn Sie die Ray-Laufzeitumgebung mit Ray-Cluster und Ray-Jobs kombinieren, können Sie SCP/SSH-Zyklus. Angenommen, Sie haben die obigen Beispiele befolgt, können Sie diesen Befehl mit folgendem Code ausführen:
$ python3 legacy/run_basic_jax.py
Die Ausgabe sieht in etwa so aus:
2023-03-01 22:12:10,065 INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379... 2023-03-01 22:12:10,072 INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265 W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu... Creating TPU: $USER-ray-test Request: {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}} Create TPU operation still running... ... Create TPU operation complete. I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster... … I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host. 2023-03-01 22:15:18,010 INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip. 2023-03-01 22:15:18,010 INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'. 2023-03-01 22:15:18,080 INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload. I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs. ... I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089 21286 credentials_generic.cc:35] Could not get HOME environment variable. 8 I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499 8561 credentials_generic.cc:35] Could not get HOME environment variable. 8
Fehlertolerantes Training
Dieses Beispiel zeigt, wie Sie mit RayTpuController
Fehler implementieren können
tolerantes Training. In diesem Beispiel trainieren wir eine einfache LLM auf PAX auf einer v4-16 vor. Sie können diese PAX-Arbeitslast jedoch durch eine andere lang laufende Arbeitslast ersetzen. Den Quellcode finden Sie unter
run_pax_autoresume.py
So führen Sie das Beispiel aus:
Klonen Sie
paxml
auf Ihre Koordinator-VM:$ git clone https://github.com/google/paxml.git
Um die Benutzerfreundlichkeit der Ray Runtime Environment für das Anwenden und Bereitstellen von JAX-Änderungen zu demonstrieren, müssen Sie in diesem Beispiel PAX ändern.
Neue Testkonfiguration hinzufügen:
$ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py @experiment_registry.register class TestModel(LmCloudSpmd2BLimitSteps): ICI_MESH_SHAPE = [1, 4, 2] CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ def task(self) -> tasks_lib.SingleTask.HParams: task_p = super().task() task_p.train.num_train_steps = 1000 task_p.train.save_interval_steps = 100 return task_p EOT
Führen Sie
run_pax_autoresume.py
aus.$ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
Experimentieren Sie während der Ausführung der Arbeitslast, was passiert, wenn Sie Ihre TPU löschen, mit dem Namen
$USER-tpu-ray
:$ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
Ray erkennt, dass die TPU ausgefallen ist, und zeigt die folgende Meldung an:
I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata. W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu... 2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a (1) raylet crashes unexpectedly (OOM, preempted node, etc.) (2) raylet has lagging heartbeats due to slow network or busy workload. 2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a (1) raylet crashes unexpectedly (OOM, preempted node, etc.) (2) raylet has lagging heartbeats due to slow network or busy workload.
Der Job erstellt dann automatisch die TPU-VM neu und startet den Trainingsjob neu, damit das Training ab dem letzten Checkpoint fortgesetzt werden kann (Schritt 200 in diesem Beispiel):
I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting... I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`... I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`... I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
HyperParameter-Suche
In diesem Beispiel wird gezeigt, wie Sie Ray Tune aus der Ray AIR verwenden, um die Hyperparameter von MNIST aus JAX/FLAX zu optimieren. Den Quellcode finden Sie unter
run_hp_search.py
So führen Sie das Beispiel aus:
Installieren Sie die erforderlichen Komponenten:
$ pip3 install -r src/tune/requirements.txt
Führen Sie
run_hp_search.py
aus.$ python3 src/tune/run_hp_search.py
Die Ausgabe sieht in etwa so aus:
Number of trials: 3/3 (3 TERMINATED) +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+ | Trial name | status | loc | learning_rate | momentum | acc | iter | total time (s) | |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------| | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 | 1.15258e-09 | 0.897988 | 0.0982 | 3 | 82.4525 | | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 | 0.000219523 | 0.825463 | 0.1009 | 3 | 73.1168 | | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 | 1.08035e-08 | 0.660416 | 0.098 | 3 | 71.6813 | +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+ 2023-03-02 21:50:47,378 INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop). ...
Fehlerbehebung
Ray-Head-Knoten kann keine Verbindung herstellen
Wenn Sie eine Arbeitslast ausführen, die den TPU-Lebenszyklus erstellt oder löscht, werden die TPU-Hosts manchmal nicht vom Ray-Cluster getrennt. Dies kann als gRPC-Fehler angezeigt werden, der signalisiert, dass der Ray-Leitknoten keine Verbindung zu einer Reihe von IP-Adressen herstellen kann.
Daher müssen Sie möglicherweise Ihre Ray-Sitzung beenden (ray stop
) und neu starten (ray start --head --port=6379 --num-cpus=0
).
Ray-Job schlägt direkt ohne Protokollausgabe fehl
PAX befindet sich in der Testphase und dieses Beispiel funktioniert möglicherweise nicht aufgrund von Pip-Abhängigkeiten. In diesem Fall sehen Sie möglicherweise Folgendes:
I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs. I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures. I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2}) I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED. 2023-03-03 20:51:38,641 INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload. 2023-03-03 20:51:38,706 INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
Die Ursache des Fehlers finden Sie unter http://127.0.0.1:8265/ im Dashboard für laufende und fehlgeschlagene Jobs. runtime_env_agent.log
zeigt alle Fehlerinformationen zur Einrichtung von runtime_env an, z. B.:
60 INFO: pip is looking at multiple versions ofto determine which version is compatible with other requirements. This could take a while. 61 INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while. 62 ERROR: Cannot install paxml because these package versions have conflicting dependencies. 63 64 The conflict is caused by: 65 praxis 0.3.0 depends on t5x 66 praxis 0.2.1 depends on t5x 67 praxis 0.2.0 depends on t5x 68 praxis 0.1 depends on t5x 69 70 To fix this you could try to: 71 1. loosen the range of package versions you've specified 72 2. remove package versions to allow pip attempt to solve the dependency conflict 73 74 ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts