Ray を使用して ML ワークロードをスケーリングする

はじめに

Cloud TPU Ray ツールは、Cloud TPU APIRay Jobs を組み合わせて、Cloud TPU でのユーザーの開発エクスペリエンスを改善することを目的としています。このユーザーガイドでは、Cloud TPU で Ray を使用する方法の最小限の例を説明します。この例は、本番環境サービスで使用されるものではなく、例示のみを目的としています。

このツールの内容

便宜上、このツールには以下が用意されています。

  • TPU の一般的なアクションのために、ボイラープレートを隠す一般的な抽象化
  • 独自の基本的なワークフローにフォークできる例

特に、以下の点に注意してください。

  • tpu_api.py: Cloud TPU API を使用した基本的な TPU オペレーション用の Python ラッパー。
  • tpu_controller.py: TPU のクラス表現。これは、本質的に tpu_api.py のラッパーです。
  • ray_tpu_controller.py: Ray 機能を備えた TPU コントローラ。これにより、Ray クラスタと Ray ジョブのボイラープレートが抽象化されます。
  • run_basic_jax.py: print(jax.device_count())RayTpuController を使用する方法を示す基本的な例。
  • run_hp_search.py: MNIST の JAX/Flax で Ray Tune を使用する方法を示す基本的な例。
  • run_pax_autoresume.py: ワークロードの例として PAX を使用して、フォールト トレラント トレーニングで RayTpuController を使用する方法を示す例。

Ray クラスタ ヘッドノードの設定

TPU Pod で Ray を使用する基本的な方法の 1 つは、TPU Pod を Ray クラスタとして設定することです。コーディネーターとして別の CPU VM を作成するのが自然な方法です。次の図に、Ray クラスタ構成の例を示します。

Ray クラスタの構成例

次のコマンドは、Google Cloud CLI を使用して Ray クラスタを設定する方法を示しています。

$ gcloud compute instances create my_tpu_admin --machine-type=n1-standard-4 ...
$ gcloud compute ssh my_tpu_admin

$ (vm) pip3 install ray[default]
$ (vm) ray start --head --port=6379 --num-cpus=0
...
# (Ray returns the IP address of the HEAD node, for example, RAY_HEAD_IP)
$ (vm) gcloud compute tpus tpu-vm create $TPU_NAME ... --metadata startup-script="pip3 install ray && ray start --address=$RAY_HEAD_IP --resources='{\"tpu_host\": 1}'"

便宜上、コーディネーター VM を作成して、このフォルダの内容をコーディネーター VM にデプロイする基本的なスクリプトも提供しています。ソースコードについては、create_cpu.shdeploy.sh をご覧ください。

このスクリプトでは、以下のいくつかのデフォルト値が設定されます。

  • create_cpu.sh は、$USER-admin という名前の VM を作成し、gcloud config のデフォルトに設定されているプロジェクトとゾーンを利用します。これらのデフォルト値を確認するには、gcloud config list を実行します。
  • create_cpu.sh は、デフォルトで 200 GB のブートディスク サイズを割り当てます。
  • deploy.sh は、VM 名が $USER-admin であることを前提としています。create_cpu.sh の値を変更する場合は、deploy.sh で必ず変更してください。

コンビニエンス スクリプトを使用するには:

  1. GitHub リポジトリのクローンをローカルマシンに作成し、ray_tpu フォルダに入力します。

    $ git clone https://github.com/tensorflow/tpu.git
    $ cd tpu/tools/ray_tpu/
    
  2. TPU 管理専用のサービス アカウント(強く推奨)がない場合は、設定します。

    $ ./create_tpu_service_account.sh
    
  3. コーディネーター VM を作成します。

    $ ./create_cpu.sh
    

    このスクリプトは、起動スクリプトを使用して VM に依存関係をインストールし、起動スクリプトが完了するまで自動的にブロックされます。

  4. ローカル コードをコーディネーター VM にデプロイします。

    $ ./deploy.sh
    
  5. VM に SSH で接続する:

    $ gcloud compute ssh $USER-admin -- -L8265:localhost:8265
    

    Ray が自動的にポート 8265 でダッシュボードを開始するため、ポート転送はここで有効になります。コーディネーター VM に SSH 接続したマシンから、http://127.0.0.1:8265/ でこのダッシュボードにアクセスできます。

  6. ステップ 0 をスキップした場合は、CPU VM 内に gcloud 認証情報を設定します。

    $ (vm) gcloud auth login --update-adc
    

    このステップでは、プロジェクト ID 情報を設定し、コーディネーター VM で Cloud TPU API を実行できるようにします。

  7. インストール要件:

    $ (vm) pip3 install -r src/requirements.txt
    
  8. コーディネーター VM で Ray を起動すると、コーディネーター VM が Ray クラスタのヘッドノードになります。

    $ (vm) ray start --head --port=6379 --num-cpus=0
    

使用例

基本的な JAX の例

run_basic_jax.py は、TPU VM を使用する Ray クラスタで Ray Jobs と Ray ランタイム環境を使用して、JAX ワークロードを実行する方法の最低例です。

JAX や PyTorch/XLA PJRT など、マルチコントローラ プログラミング モデルを使用する Cloud TPU と互換性のある ML フレームワークの場合は、ホストごとに少なくとも 1 つのプロセスを実行する必要があります。詳細については、マルチプロセス プログラミング モデルをご覧ください。実際には、次のようになります。

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py"

v4-128 など、約 16 を超えるホストがある場合、SSH のスケーラビリティに関する問題が発生するため、コマンドを次のように変更する必要があります。

$ gcloud compute tpus tpu-vm scp my_bug_free_python_code my_tpu:~/ --worker=all --batch-size=8
$ gcloud compute tpus tpu-vm ssh my_tpu --worker=all --command="python3 ~/my_bug_free_python_code/main.py &" --batch-size=8

これは、my_bug_free_python_code にバグが含まれている場合、デベロッパーの速度の妨げになる可能性があります。この問題を解決する方法の 1 つは、Kubernetes や Ray などのオーケストレーターを使用することです。Ray には、ランタイム環境のコンセプトが含まれており、適用すると、Ray アプリケーションの実行中にコードと依存関係がデプロイされます。

Ray ランタイム環境と Ray クラスタおよび Ray ジョブを組み合わせることで、SCP/SSH サイクルをバイパスできます。上記の例の場合、次のコマンドで実行できます。

$ python3 legacy/run_basic_jax.py

出力は次のようになります。

2023-03-01 22:12:10,065   INFO worker.py:1364 -- Connecting to existing Ray cluster at address: 10.130.0.19:6379...
2023-03-01 22:12:10,072   INFO worker.py:1544 -- Connected to Ray cluster. View the dashboard at http://127.0.0.1:8265
W0301 22:12:11.148555 140341931026240 ray_tpu_controller.py:143] TPU is not found, create tpu...
Creating TPU:  $USER-ray-test
Request:  {'accelerator_config': {'topology': '2x2x2', 'type': 'V4'}, 'runtimeVersion': 'tpu-ubuntu2204-base', 'networkConfig': {'enableExternalIps': True}, 'metadata': {'startup-script': '#! /bin/bash\necho "hello world"\nmkdir -p /dev/shm\nsudo mount -t tmpfs -o size=100g tmpfs /dev/shm\n pip3 install ray[default]\nray start --resources=\'{"tpu_host": 1}\' --address=10.130.0.19:6379'}}
Create TPU operation still running...
...
Create TPU operation complete.
I0301 22:13:17.795493 140341931026240 ray_tpu_controller.py:121] Detected 0 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:13:17.795823 140341931026240 ray_tpu_controller.py:160] Waiting for 30s for TPU hosts to join cluster...
…
I0301 22:15:17.986352 140341931026240 ray_tpu_controller.py:121] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
I0301 22:15:17.986503 140341931026240 ray_tpu_controller.py:90] Ray already started on each host.
2023-03-01 22:15:18,010   INFO dashboard_sdk.py:315 -- Uploading package gcs://_ray_pkg_3599972ae38ce933.zip.
2023-03-01 22:15:18,010   INFO packaging.py:503 -- Creating a file package for local directory '/home/$USER/src'.
2023-03-01 22:15:18,080   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_3599972ae38ce933.zip already exists, skipping upload.
I0301 22:15:18.455581 140341931026240 ray_tpu_controller.py:169] Queued 2 jobs.
...
I0301 22:15:48.523541 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_WRUtVB7nMaRTgK39: Status is SUCCEEDED
I0301 22:15:48.561111 140341931026240 ray_tpu_controller.py:256] [raysubmit_WRUtVB7nMaRTgK39]: E0301 22:15:36.294834089   21286 credentials_generic.cc:35]            Could not get HOME environment variable.
8

I0301 22:15:58.575289 140341931026240 ray_tpu_controller.py:254] [ADMIN]: raysubmit_yPCPXHiFgaCK2rBY: Status is SUCCEEDED
I0301 22:15:58.584667 140341931026240 ray_tpu_controller.py:256] [raysubmit_yPCPXHiFgaCK2rBY]: E0301 22:15:35.720800499    8561 credentials_generic.cc:35]            Could not get HOME environment variable.
8

フォールト トレラント トレーニング

この例では、RayTpuController を使用してフォールト トレラントなトレーニングを実装する方法を示します。この例では、v4-16 の PAX に単純な LLM を事前トレーニングします。ただし、この PAX ワークロードを他の長時間実行ワークロードに置き換えることもできます。ソースコードについては、run_pax_autoresume.py をご覧ください。

この例を実行するには:

  1. コーディネーター VM に paxml のクローンを作成します。

    $ git clone https://github.com/google/paxml.git
    

    Ray Runtime Environment の JAX 変更の使いやすさを説明して、この例では PAX を変更する必要があります。

  2. 新しいテスト構成を追加します。

    $ cat <<EOT >> paxml/paxml/tasks/lm/params/lm_cloud.py
    
    @experiment_registry.register
    class TestModel(LmCloudSpmd2BLimitSteps):
    ICI_MESH_SHAPE = [1, 4, 2]
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_CONTEXT_AND_OUT_PROJ
    
    def task(self) -> tasks_lib.SingleTask.HParams:
      task_p = super().task()
      task_p.train.num_train_steps = 1000
      task_p.train.save_interval_steps = 100
      return task_p
    EOT
    
  3. run_pax_autoresume.py を実行します。

    $ python3 legacy/run_pax_autoresume.py --model_dir=gs://your/gcs/bucket
    
  4. ワークロードを実行する際に、デフォルトで $USER-tpu-ray という TPU を削除した場合の動作を確認します。

    $ gcloud compute tpus tpu-vm delete -q $USER-tpu-ray --zone=us-central2-b
    

    Ray は、次のメッセージを使用して TPU が停止していることを検出します。

    I0303 05:12:47.384248 140280737294144 checkpointer.py:64] Saving item to gs://$USER-us-central2/pax/v4-16-autoresume-test/checkpoints/checkpoint_00000200/metadata.
    W0303 05:15:17.707648 140051311609600 ray_tpu_controller.py:127] TPU is not found, create tpu...
    2023-03-03 05:15:30,774 WARNING worker.py:1866 -- The node with node id: 9426f44574cce4866be798cfed308f2d3e21ba69487d422872cdd6e3 and address: 10.130.0.113 and node name: 10.130.0.113 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    2023-03-03 05:15:33,243 WARNING worker.py:1866 -- The node with node id: 214f5e4656d1ef48f99148ddde46448253fe18672534467ee94b02ba and address: 10.130.0.114 and node name: 10.130.0.114 has been marked dead because the detector has missed too many heartbeats from it. This can happen when a       (1) raylet crashes unexpectedly (OOM, preempted node, etc.)
          (2) raylet has lagging heartbeats due to slow network or busy workload.
    

    その後、ジョブは自動的に TPU VM を再作成してトレーニング ジョブを再起動します。これにより、最新のチェックポイントからトレーニングを再開できます(この例では 200 ステップ)。

    I0303 05:22:43.141277 140226398705472 train.py:1149] Training loop starting...
    I0303 05:22:43.141381 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/train`...
    I0303 05:22:43.353654 140226398705472 summary_utils.py:267] Opening SummaryWriter `gs://$USER-us-central2/pax/v4-16-autoresume-test/summaries/eval_train`...
    I0303 05:22:44.008952 140226398705472 py_utils.py:350] Starting sync_global_devices Start training loop from step: 200 across 8 devices globally
    

この例では、Ray AIR の Ray Tune と JAX/FLAX のハイパーパラメータ調整 MNIST を使用します。ソースコードについては、run_hp_search.py をご覧ください。

この例を実行するには:

  1. 要件をインストールします:

    $ pip3 install -r src/tune/requirements.txt
    
  2. run_hp_search.py を実行します。

    $ python3 src/tune/run_hp_search.py
    

    出力は次のようになります。

    Number of trials: 3/3 (3 TERMINATED)
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    | Trial name                  | status     | loc               |   learning_rate |   momentum |    acc |   iter |   total time (s) |
    |-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------|
    | hp_search_mnist_8cbbb_00000 | TERMINATED | 10.130.0.84:21340 |     1.15258e-09 |   0.897988 | 0.0982 |      3 |          82.4525 |
    | hp_search_mnist_8cbbb_00001 | TERMINATED | 10.130.0.84:21340 |     0.000219523 |   0.825463 | 0.1009 |      3 |          73.1168 |
    | hp_search_mnist_8cbbb_00002 | TERMINATED | 10.130.0.84:21340 |     1.08035e-08 |   0.660416 | 0.098  |      3 |          71.6813 |
    +-----------------------------+------------+-------------------+-----------------+------------+--------+--------+------------------+
    
    2023-03-02 21:50:47,378   INFO tune.py:798 -- Total run time: 318.07 seconds (318.01 seconds for the tuning loop).
    ...
    

トラブルシューティング

Ray head ノードが接続できない

TPU のライフサイクルを作成または削除するワークロードを実行しても、TPU ホストが Ray クラスタから接続解除されないことがあります。Ray ヘッドノードが一連の IP アドレスに接続できないことを示す gRPC エラーが表示されることがあります。

その結果、レイ セッション(ray stop)を終了して再起動(ray start --head --port=6379 --num-cpus=0)する必要が生じる場合があります。

Ray Job がログ出力なしで直接失敗する

PAX は試験運用版であり、この例は pip の依存関係が原因で破損する可能性があります。この場合、次のような内容が表示されます。

I0303 20:50:36.084963 140306486654720 ray_tpu_controller.py:174] Queued 2 jobs.
I0303 20:50:36.136786 140306486654720 ray_tpu_controller.py:238] Requested to clean up 1 stale jobs from previous failures.
I0303 20:50:36.148653 140306486654720 ray_tpu_controller.py:253] Job status: Counter({<JobStatus.FAILED: 'FAILED'>: 2})
I0303 20:51:38.582798 140306486654720 ray_tpu_controller.py:126] Detected 2 TPU hosts in cluster, expecting 2 hosts in total
W0303 20:51:38.589029 140306486654720 ray_tpu_controller.py:196] Detected job raysubmit_8j85YLdHH9pPrmuz FAILED.
2023-03-03 20:51:38,641   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.
2023-03-03 20:51:38,706   INFO dashboard_sdk.py:362 -- Package gcs://_ray_pkg_ae3cacd575e24531.zip already exists, skipping upload.

エラーの根本原因を確認するには、http://127.0.0.1:8265/ に移動し、実行中または失敗したジョブのダッシュボードを表示して、詳細情報を確認できます。runtime_env_agent.log は、runtime_env の設定に関連するすべてのエラー情報を表示します。次に例を示します。

60    INFO: pip is looking at multiple versions of  to determine which version is compatible with other requirements. This could take a while.
61    INFO: pip is looking at multiple versions of orbax to determine which version is compatible with other requirements. This could take a while.
62    ERROR: Cannot install paxml because these package versions have conflicting dependencies.
63
64    The conflict is caused by:
65        praxis 0.3.0 depends on t5x
66        praxis 0.2.1 depends on t5x
67        praxis 0.2.0 depends on t5x
68        praxis 0.1 depends on t5x
69
70    To fix this you could try to:
71    1. loosen the range of package versions you've specified
72    2. remove package versions to allow pip attempt to solve the dependency conflict
73
74    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts