GKE で TPU と JetStream を使用して Gemma をサービングする

このチュートリアルでは、Google Kubernetes Engine(GKE)で Tensor Processing Unit(TPU)を使用して Gemma 大規模言語モデル(LLM)をサービングする方法について説明します。JetStreamMaxText を含むビルド済みコンテナを GKE にデプロイします。また、実行時に Cloud Storage から Gemma 7B の重みを読み込むように GKE を構成します。

このチュートリアルは、LLM の提供に Kubernetes コンテナ オーケストレーション機能を使用する ML エンジニア、プラットフォーム管理者、オペレーター、データおよび AI スペシャリストを対象としています。Google Cloud のコンテンツで使用されている一般的なロールとタスクの例の詳細については、一般的な GKE ユーザーのロールとタスクをご覧ください。

このページを読む前に、次のことをよく理解しておいてください。

背景

このセクションでは、このチュートリアルで使用されている重要なテクノロジーについて説明します。

Gemma

Gemma は、オープン ライセンスでリリースされ一般公開されている、軽量の生成 AI モデルのセットです。これらの AI モデルは、アプリケーション、ハードウェア、モバイル デバイス、ホスト型サービスで実行できます。Gemma モデルはテキスト生成に使用できますが、特殊なタスク用にチューニングすることもできます。

詳しくは、Gemma のドキュメントをご覧ください。

TPU

TPU は、Google が独自に開発した特定用途向け集積回路(ASIC)であり、TensorFlowPyTorchJAX などのフレームワークを使用して構築された ML モデルと AI モデルを高速化するために使用されます。

このチュートリアルでは、Gemma 7B モデルのサービングについて説明します。GKE は、低レイテンシでプロンプトをサービングするモデルの要件に基づいて構成された TPU トポロジを使用して、単一ホストの TPUv5e ノードにモデルをデプロイします。

JetStream

JetStream は、Google が開発したオープンソースの推論サービング フレームワークです。JetStream を使用すると、TPU と GPU で高性能、高スループット、メモリ最適化された推論が可能になります。継続的なバッチ処理や量子化技術などの高度なパフォーマンス最適化により、LLM のデプロイを容易にします。JetStream では、PyTorch / XLA と JAX TPU のサービングにより、最適なパフォーマンスを実現できます。

これらの最適化の詳細については、JetStream PyTorchJetStream MaxText のプロジェクト リポジトリをご覧ください。

MaxText

MaxText は、FlaxOrbaxOptax などのオープンソースの JAX ライブラリ上に構築された、パフォーマンス、スケーラビリティ、適応性に優れた JAX LLM 実装です。MaxText のデコーダ専用の LLM 実装は Python で記述されています。XLA コンパイラの活用により、カスタム カーネルを構築しなくても高いパフォーマンスを実現できます。

MaxText がサポートする最新のモデルとパラメータ サイズの詳細については、MaxtText プロジェクト リポジトリをご覧ください。