Ray on GKE で Cloud TPU をよりネイティブに利用
Nisha Mariam Johnson
Product Manager
Ryan O'Leary
Software Engineer
※この投稿は米国時間 2025 年 11 月 4 日に、Google Cloud blog に投稿されたものの抄訳です。
エンジニアリング チームは、GPU と Cloud TPU の両方を含む幅広いハードウェアで AI ワークロードをスケーリングするために Ray を使用しています。Ray はコアとなるスケーリング機能を提供する一方、開発者は多くの場合、各アクセラレータの固有のアーキテクチャの詳細を管理してきました。Cloud TPU には、その特定のネットワーキング モデルと、単一プログラム複数データ(SPMD)プログラミング スタイルが含まれます。
Google は、Anyscale とのパートナーシップの一環として、Google Kubernetes Engine(GKE)で TPU を使用する際のエンジニアリング作業を削減する取り組みを進めています。その目標は、TPU での Ray の使用をできるだけネイティブで低摩擦なものにすることです。
本日、Google はそれを可能にするための重要な改善をいくつかリリースします。
Ray TPU ライブラリで、Ray Core における TPU の認識とスケーリングを改善
TPU には、独自のアーキテクチャと、SPMD と呼ばれる特定のプログラミング スタイルがあります。大規模な AI ジョブは、チップ間相互接続(ICI)と呼ばれる高速ネットワーキングで接続されたチップの集合体である TPU スライスで実行されます。


以前は、この特定のハードウェア トポロジを認識するように Ray を手動で設定する必要がありました。これは重要なセットアップ手順であり、正しく行われなければ、ジョブが接続されていない異なるスライスからリソースを断片的に取得し、深刻なパフォーマンス ボトルネックを引き起こす可能性がありました。
この新しいライブラリ ray.util.tpu では、ユーザーがこれらのハードウェアの詳細を設定する必要がなくなりました。SlicePlacementGroup という機能と新しい label_selector API を使用して、コロケーションされた TPU スライス全体を 1 つのアトミック ユニットとして自動的に予約します。これにより、ジョブは統合されたハードウェアで実行されることが保証され、断片化によるパフォーマンスの問題を回避できます。Ray ではこれまで、この単一スライスのアトミック性を保証できなかったため、信頼性の高い真のマルチスライス トレーニング(意図的に複数のユニークなスライスにまたがる)を構築することは不可能でした。この新しい API は、Ray ユーザーがマルチスライス テクノロジーを使用して複数の TPU スライスでスケーリングするための重要な基盤も提供します。
Jax、Ray Train、Ray Serve のサポートを拡大
Google の開発は、トレーニングと推論の両方に関わっています。トレーニングに関して、Ray Train は TPU 上の JAX(JaxTrainer 経由)と PyTorch のアルファ版サポートを提供しています。
JaxTrainer API を使用すると、マルチホスト TPU での JAX ワークロードの実行が簡素化されます。複雑な分散ホストの初期化を自動的に処理するようになりました。以下のコード例に示すように、ワーカー数、トポロジ、アクセラレータ タイプなどのハードウェア要件を、シンプルな ScalingConfig オブジェクト内で定義するだけで済みます。残りの部分は JaxTrainer が行います。
これは、リソースの断片化という重大なパフォーマンス上の問題を解決する、大きな改善点です。以前は、「4x4」トポロジをリクエストするジョブ(スライスと呼ばれる単一のコロケーション ハードウェア ユニットで実行する必要がある)が、代わりに断片化されたリソースを受け取ることがありました。たとえば、1 つの物理スライスから 8 個のチップ、別の接続されていないスライスから 8 個のチップなどです。この断片化は、単一の統合されたスライス内にのみ存在する高速 ICI 相互接続をワークロードが使用できないため、大きなボトルネックとなっていました。
JaxTrainer がマルチホスト TPU でのトレーニングを簡素化する例:
Ray Serve API は TPU をサポートしており、vLLM TPU の改善により、TPU に移行する際も vLLM で Ray を引き続き使用できます。これにより、GPU で使用しているのと同じスタックを、最小限のコード変更で TPU で実行できます。
ラベルベースのスケジューリング API で簡単に取得可能
新しいラベルベースのスケジューリング API は、GKE カスタム コンピューティング クラスと統合されています。カスタム コンピューティング クラスは、名前付きのハードウェア構成を定義する簡単な方法です。たとえば、cost-optimized というクラスを作成して、GKE にまず Spot インスタンスの取得を試み、次に Dynamic Workload Scheduler FlexStart インスタンスにフォールバックし、最終的に最後の手段として予約インスタンスにフォールバックするように指示できます。新しい Ray API では、Python からクラスを直接使用できます。シンプルな label_selector を使用して、「TPU-V6E」などのハードウェアをリクエストしたり、費用対効果が最適化されたクラスをターゲットにしたりでき、これらすべては個別の YAML ファイルを管理することなく行えます。
この同じ label_selector メカニズムは、TPU の詳細なハードウェア制御も公開します。GKE は、スライス用の TPU Pod をプロビジョニングする際に、メタデータ(ワーカーランクやトポロジなど)を各 Pod に挿入します。KubeRay(GKE 上の Ray を管理)は、GKE が提供するこのメタデータを読み取り、ノードの作成時に自動的に Ray 固有のラベルに変換します。これにより、TPU の世代(ray.io/accelerator-type)、物理チップのトポロジ(ray.io/tpu-topology)、スライス内のワーカーランク(ray.io/tpu-worker-id)などの重要な情報が提供されます。
これらのノードラベルを使用すると、Ray の label_selector を使用して、SPMD ワークロードを特定のコロケーション ハードウェア(「4x4」トポロジや特定のワーカーランクなど)に固定できます。
以下の例では、Ray ユーザーが v6e-32 TPU スライスをリクエストしていますが、GKE にカスタム コンピューティング クラスを使用して、v6e-32 が利用できない場合は v5e-16 にフォールバックするように指示しています。同様に、ユーザーはスポット リソースまたは DWS リソースをリクエストすることから始め、それらが利用できない場合は、予約インスタンスにフォールバックできます。
TPU の指標とログを 1 か所に
TensorCore 使用率、デューティ サイクル、高帯域幅メモリ(HBM)使用率、メモリ帯域幅使用率などの主要な TPU パフォーマンス指標を、Ray ダッシュボードで直接確認できるようになりました。また、低レベルの libtpu ログも追加しました。これにより、コードが原因で障害が発生したのか、TPU ハードウェア自体が原因で障害が発生したのかをすぐに確認できるため、デバッグが大幅に高速化されます。
使ってみる
これらのアップデートは、TPU を Ray エコシステムにシームレスに組み込む、大きな一歩です。これにより、既存の Ray アプリケーションを GPU と TPU の間で適応させるプロセスがはるかに分かりやすいものになります。詳細とご利用開始方法は次のとおりです。
-
ドキュメントを読む:
-
JAX ワークロード: JaxTrainer の使用方法については、新しいJAX を使ってみるガイドをご覧ください。JaxTrain の詳細もご覧ください。
-
TPU 指標: TPU 指標を Ray ダッシュボードまたは Grafana で表示
-
TPU 容量のリクエスト: 7 日未満で実行されるジョブに TPU へのアクセスを提供する TPU 向け DWS Flex Start を使用して、すぐに開始できます。
-
関連コンテンツ: TPU の概要
-Nisha Mariam Johnson、プロダクト マネージャー
-Ryan O'Leary、ソフトウェア エンジニア
