Ray を使用して ML ワークロードをスケーリングする

はじめに

Cloud TPU Ray ツールは、Cloud TPU でのユーザーの開発エクスペリエンスを向上させることを目的として、Cloud TPU APIRay Jobs を組み合わせたものです。このユーザーガイドでは、Cloud TPU で Ray を使用する方法についての簡単な例を示します。これらの例は、本番サービスでの使用を意図したものではなく、あくまで説明のためのものです。

このツールの内容

利便性を高めるために、このツールには次の機能があります。

  • 一般的な TPU アクションのボイラープレートを隠す汎用的な抽象化
  • 独自の基本ワークフローにフォークできる簡単な例

具体的な内容は以下のとおりです。

  • tpu_api.py: Cloud TPU API を使用する基本的な TPU オペレーション用の Python ラッパー。
  • tpu_controller.py: TPU のクラス表現。これは基本的に tpu_api.py のラッパーです。
  • ray_tpu_controller.py: Ray 機能を備えた TPU コントローラ。これにより、Ray クラスタと Ray Jobs のボイラープレートが抽象化されます。
  • run_basic_jax.py: print(jax.device_count())RayTpuController を使用する方法を示す基本的な例。
  • run_hp_search.py: MNIST で JAX / Flax と Ray Tune を併用する方法を示す基本的な例。
  • run_pax_autoresume.py: PAX をサンプル ワークロードとして使用して、フォールト トレラントのトレーニングに RayTpuController を使用する方法を示す例。

Ray クラスタ ヘッドノードの設定

TPU Pod で Ray を使用する基本的な方法の一つは、TPU Pod を Ray クラスタとして設定することです。そのためには、コーディネーター VM として別の CPU VM を作成するのが自然な方法です。次の図は、Ray クラスタ構成の例を示しています。

Ray クラスタ構成の例

次のコマンドは、Google Cloud CLI を使用して Ray クラスタを設定する方法を示しています。

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

参考までに、コーディネーター VM を作成して、このフォルダの内容をコーディネーター VM にデプロイするための基本的なスクリプトも用意されています。ソースコードについては、create_cpu.shdeploy.sh をご覧ください。

このスクリプトでは、以下のいくつかのデフォルト値が設定されます。

  • create_cpu.sh$USER-admin という名前の VM を作成し、gcloud config のデフォルトに設定されているプロジェクトとゾーンを利用します。gcloud config list を実行して、これらのデフォルトを確認します。
  • デフォルトでは、create_cpu.sh は 200 GB のブートディスクを割り当てます。
  • deploy.sh は、VM 名が $USER-admin であることを前提としています。create_cpu.sh でその値を変更する場合は、必ず deploy.sh でも変更してください。

便利なスクリプトを使用するには:

  1. GitHub リポジトリのクローンをローカルマシンに作成し、ray_tpu フォルダに移動します。

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
  2. TPU 管理専用のサービス アカウント(強く推奨)がない場合は、そのサービス アカウントを設定します。

    $ ./create_tpu_service_account.sh
  3. コーディネーター VM を作成します。

    $ ./create_cpu.sh

    このスクリプトは、起動スクリプトを使用して VM に依存関係をインストールし、起動スクリプトが完了するまで自動的にブロックします。

  4. ローカルコードをコーディネーター VM にデプロイします。

    $ ./deploy.sh
  5. VM に SSH で接続する:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265

    Ray はポート 8265 でダッシュボードを自動的に起動するため、ここではポート転送が有効になっています。コーディネーター VM に SSH 接続したマシンから、http://127.0.0.1:8265/ でこのダッシュボードにアクセスできます。

  6. ステップ 0 をスキップした場合は、CPU VM 内で gcloud 認証情報を設定します。

    $ (vm) gcloud auth login --update-adc

    このステップでは、プロジェクト ID 情報を設定し、Cloud TPU API がコーディネーター VM で実行できるようにします。

  7. インストール要件:

    $ (vm) pip3 install -r src/requirements.txt
  8. コーディネーター VM で Ray を起動すると、コーディネーター VM が Ray クラスタのヘッドノードになります。

    $ (vm) ray start --head --port=6379 --num-cpus=0

使用例

基本的な JAX の例

run_basic_jax.py は、TPU VM を使用する Ray クラスタで Ray Jobs と Ray ランタイム環境を使用して JAX ワークロードを実行する方法を示した最小限の例です。

JAX や PyTorch / XLA PJRT などのマルチコントローラ プログラミング モデルを使用する Cloud TPU と互換性のある ML フレームワークの場合は、ホストごとに少なくとも 1 つのプロセスを実行する必要があります。詳細については、マルチプロセス プログラミング モデルをご覧ください。実際には、以下のようになります。

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

v4-128 のように 16 台を超えるホストがある場合は、SSH のスケーラビリティの問題が発生して、コマンドを次のように変更する必要がある可能性があります。

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

my_bug_free_python_code にバグが含まれている場合、これがデベロッパーの作業の進捗を妨げる可能性があります。この問題を解決する方法の一つは、Kubernetes や Ray などのオーケストレーターを使用することです。Ray には、ランタイム環境のコンセプトが含まれています。この環境を適用すると、Ray アプリケーションの実行時にコードと依存関係がデプロイされます。

Ray ランタイム環境を Ray クラスタと Ray Jobs と組み合わせることで、SCP / SSH サイクルを回避できます。上記の例に沿ったと仮定すると、次のコマンドで実行できます。

$ python3 legacy/run_basic_jax.py

出力は次のようになります。

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...

I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

フォールト トレラント トレーニング

この例では、RayTpuController を使用してフォールト トレラントのトレーニングを実装する方法を示します。この例では、v4-16 の PAX で単純な LLM を事前トレーニングしますが、この PAX ワークロードは、長時間実行される他のワークロードに置き換えることができます。ソースコードについては、run_pax_autoresume.py をご覧ください。

この例を実行するには:

  1. paxml のクローンをコーディネーター VM に作成します。

    $ git clone https://github.com/google/paxml.git

    Ray ランタイム環境で JAX の変更とデプロイを簡単に行えることを示すため、この例では PAX を変更する必要があります。

  2. 新しいテスト構成を追加します。

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. run_pax_autoresume.py を実行します。

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
  4. ワークロードの実行中に、デフォルトで $USER-tpu-ray という名前の TPU を削除した場合の影響をテストします。

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b

    Ray は TPU が停止していることを検出し、次のメッセージを表示します。

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.

    ジョブは TPU VM を自動的に再作成し、最新のチェックポイント(この例では 200 ステップ)からトレーニングを再開できるようにトレーニング ジョブを再起動します。

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally

この例では、Ray AIR の Ray Tune を使用して、JAX / FLAX の MNIST をハイパーパラメータ チューニングします。ソースコードについては、run_hp_search.py をご覧ください。

この例を実行するには:

  1. 要件をインストールします:

    $ pip3 install -r src/tune/requirements.txt
  2. run_hp_search.py を実行します。

    $ python3 src/tune/run_hp_search.py

    出力は次のようになります。

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...

トラブルシューティング

Ray ヘッドノードが接続できない

TPU ライフサイクルを作成または削除するワークロードを実行すると、TPU ホストが Ray クラスタから切断されないことがあります。これは、Ray ヘッドノードが一連の IP アドレスへの接続に失敗したことを示す gRPC エラーとして表示される場合があります。

そのため、Ray セッションを終了(ray stop)して再起動(ray start --head --port=6379 --num-cpus=0)する必要がある場合があります。

Ray Job がログ出力なしですぐに失敗する

PAX は試験運用版で、この例では pip の依存関係が原因で中断する可能性があります。この場合、次のようなメッセージが表示されます。

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

エラーの根本原因を確認するには、http://127.0.0.1:8265/ に移動し、実行中ジョブまたは失敗したジョブのダッシュボードを表示します。これにより、詳細情報が提供されます。runtime_env_agent.log には、runtime_env の設定に関連するすべてのエラー情報が表示されます。次に例を示します。

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts