コンテンツに移動
Containers & Kubernetes

Google Kubernetes Engine で NVIDIA GPU を使用して JAX マルチノード アプリケーションを実行する方法

2023年3月29日
https://storage.googleapis.com/gweb-cloudblog-publish/images/containers_2022_anH39my.max-2500x2500.jpg
Google Cloud Japan Team

※この投稿は米国時間 2023 年 3 月 23 日に、Google Cloud blog に投稿されたものの抄訳です。

JAX は、高パフォーマンスの数値計算および機械学習(ML)研究を対象とした、急成長中の Python ライブラリです。大規模言語モデル、創薬、物理 ML、強化学習、ニューラル グラフィックスに応用できる JAX は、ここ数年間で信じられないほど幅広い領域に導入されてきました。JAX が開発者と研究者にもたらすメリットは、使いやすい NumPy API、自動微分、最適化など、数多くあります。また、JAX を使用すると、マルチノード / マルチ GPU の複数のシステムに対する分散処理を数行のコードで実現し、NVIDIA GPU の XLA 最適化カーネルを介してパフォーマンスを向上させることができます。

ここでは、JAX のマルチ GPU / マルチノード アプリケーションを、NVIDIA A100 80GB Tensor Core GPU を搭載した A2 Ultra マシンシリーズを使用して GKE(Google Kubernetes Engine)で実行する方法をご紹介します。それぞれ 8 個のプロセス、8 個の GPU を備えた 4 個のノードでシンプルな Hello World アプリケーションを実行します。

前提条件

  1. gcloud init を実行してプロンプトの指示に従い、gcloud のインストールと環境設定を行っていること

  2. Docker をインストールし、gcloud 認証情報ヘルパーを使用して Google Container Registry にログインしていること

  3. GCP 向けの kubectlkubectl 認証プラグインをインストールしていること

GKE クラスタの設定

リポジトリのクローンを作成します。

読み込んでいます...

必要な API を有効にします。

読み込んでいます...

デフォルトの VPC を作成します(まだ作成していない場合)。

読み込んでいます...

クラスタ(コントロール ノード)を作成します。us-central1-c は、実際の優先ゾーンで置き換えます。

読み込んでいます...

プール(コンピューティング ノード)を作成します。--enable-fast-socket --enable-gvnic はマルチノードのパフォーマンスに必須です。--preemptible により割り当ての必要がなくなりますが、ノードがプリエンプティブになります。これが望ましくない場合は、このフラグを削除します。us-central1-c は、実際の優先ゾーンで置き換えます。これには数分かかることがあります。

読み込んでいます...

コンピューティング ノードに NVIDIA CUDA ドライバをインストールします。

読み込んでいます...

コンテナをビルドして、レジストリに push します。これにより、コンテナが gcr.io/<your project>/jax/hello:latest に push されます。これには数分かかることがあります。

読み込んでいます...

kubernetes/job.yaml と kubernetes/kustomization.yaml で、<<PROJECT>> を実際の GCP プロジェクト名に変更します。

JAX の実行

コンピューティング ノードで JAX アプリケーションを実行します。これにより、それぞれが 1 個の NVIDIA GPU で 1 個の JAX プロセスを実行する 32 個の Pod(ノードあたり 8 個)が作成されます。

読み込んでいます...

次のコマンドを実行します。

読み込んでいます...

これにより、ステータスを確認できます。ステータスは、ContainerCreating から Pending(数分後)、Running、そして最後に Completed へと変化します。

ジョブが完了したら、kubectl のログで 1 個の Pod からの出力を確認します。

読み込んでいます...

このアプリケーションは [1.0] と等しい長さ 1 の配列を各プロセスに作成してから、それらすべてを累積させます。32 個のプロセスの場合、出力は各プロセスで [32.0] となります。

ここまで、GKE で 32 個の NVIDIA A100 GPU を使用して JAX を実行する方法をご説明しました。次は、TensorRT と NVIDIA T4 GPU を使用して大規模な推論を実行する方法を学びましょう。


今回のブログ投稿に際して、専門知識とガイダンスを提供してくれた Google の機械学習ソリューション アーキテクト Jarek Kazmierczak と NVIDIA のシステム ソフトウェア エンジニア Iris Liu 氏に心から感謝いたします。


- NVIDIA ソフトウェア エンジニア Leopold Cambier 氏
- クラウド ソリューション アーキテクト Roberto Barbero

投稿先