コンテンツに移動
AI & 機械学習

Google Cloud での Gemma のパフォーマンスに関する詳細

2024年4月22日
https://storage.googleapis.com/gweb-cloudblog-publish/images/Next24_Blog_blank_2-05.max-2500x2500.jpg
Google Cloud Japan Team

※この投稿は米国時間 2024 年 4 月 11 日に、Google Cloud blog に投稿されたものの抄訳です。

Google は今年の初めに Gemma 発表しました。Gemma は、開発者が Google Cloud で迅速にテストを実施し、適応、製品化ができるように構築されたオープン ウェイトのモデル ファミリーです。Gemma モデルはノートパソコンやワークステーションで実行できるほか、Cloud GPU または Cloud TPU を選択し、Vertex AI または Google Kubernetes EngineGKE)を介して Google Cloud で実行できます。これには、Cloud GPU 上の vLLMHuggingFace TGITensorRT LLM に加え、Cloud TPU 上の JetStream Hugging Face TGIOptimum-TPU)を活用し、PyTorch JAX を使用したトレーニング、ファインチューニング、推論が含まれます。

Google のベンチマークでは、Cloud TPU v5e を使用した Gemma モデルのトレーニング パフォーマンスは、ベースラインである Llama-2 のトレーニング パフォーマンスと比較して、最大 3 倍(料金あたりのパフォーマンス向上)であることが示されました。Google は今週初め、コスト効率とパフォーマンスに優れた新しい推論エンジンである JetStream をリリースしました。Cloud TPU での Gemma 推論のパフォーマンスを分析した結果、JetStream Gemma のサービングを行った場合、ベースラインとした先行 TPU 推論スタックと比較して、LLM 推論の推論効率が 3 倍向上する(料金あたりの推論回数が増える)ことがわかりました。

この投稿では、Google Cloud アクセラレータでの Gemma モデルのトレーニングと推論のパフォーマンスについて説明します。ご紹介する結果は 2024 4 月時点のスナップショットです。インフラストラクチャの効率とモデルの品質は、オープンソース コミュニティ、企業ユーザー、Google の各チームの貢献によってさらに進化し、改善されていくと予想されます。

背景: Gemma モデル アーキテクチャの詳細

Gemma ファミリーのモデルには、Gemma 2B Gemma 7B(高密度デコーダ アーキテクチャ)の 2 種類があります。Google は、Gemma 2B モデルで 2 兆個のトークン、7B モデルで 6 兆個のトークンを使い、コンテキストの長さを 8,192 トークンとして事前トレーニングしました。どちらのモデルも 256 のヘッド次元を使用し、両方のバリアントで Rotary Positional EmbeddingsRoPEを利用しています。

モデル

d_model

q_heads

kv_heads

d_ff

n_layers

Gemma 2B

2,048

8

1

16,384

18

Gemma 7B

3,072

16

16

24,576

28

Gemma 7B モデルがマルチヘッド アテンション メカニズムを活用しているのに対し、Gemma 2B マルチクエリ アテンションを利用しています。このアプローチは、推論プロセス中に必要なメモリ帯域幅を削減するのに役立ち、メモリ帯域幅が制限されがちな Gemma 2B のオンデバイス推論シナリオに有利に働く可能性があります。

Gemma トレーニングのパフォーマンス

特定モデル、あるいは同規模のモデルのカテゴリのトレーニング インフラストラクチャを評価するには、1)モデルの実効 FLOP 使用率、2)料金あたりの相対的パフォーマンスという 2 つの重要な側面を考慮する必要があります。

モデルの実効 FLOP 使用率

モデルの FLOP 使用率(MFU)とは、モデルのスループット、すなわち、基礎となるトレーニング インフラストラクチャのピーク時のスループットに対する、モデルによって実行される実際の 1 秒あたりの浮動小数点演算の比率です。モデルのスループットを計算するために、トレーニング ステップあたりの浮動小数点演算回数とステップ時間の分析推定値を使用します(参照: PaLM)。混合精度トレーニング設定(Int8)に適用された場合、結果として得られる指標はモデルの実効 FLOP 使用率(EMFU)と呼ばれます。他の条件がすべて同じである場合、(EMFU が高いほど、ユニットコストあたりのパフォーマンスが向上していることを示します。MFU の向上は、トレーニングのコスト削減に直結します。

Gemma トレーニングの設定

Gemma モデルの事前トレーニングは、TPU v5e を使って Google 社内で行われました。Gemma 2B では 2 つの v5e-256 を、Gemma 7B では 16 Cloud TPU v5e-256 を採用しました。

Google は、Cloud TPU での Gemma モデルの(EMFU を測定しました。Cloud TPU v5e Cloud TPU v5p はどちらも(この投稿を書いている時点では)最新の Cloud TPU 世代なので、両方のパフォーマンスをご紹介します。Cloud TPU v5e は、料金あたりのパフォーマンスにおいて、これまでで最も費用対効果が高い TPU です。一方、Cloud TPU v5p は、混合エキスパートや、大規模なランキングやレコメンデーションのシステムのような代替ワークロードなど、より複雑な LLM アーキテクチャに対応する、最もパワフルでスケーラブルな TPU です。

下のグラフは、bf16 精度や(AQT による)混合精度(int8)を使用した、Gemma 2B Gemma 7B のトレーニング実行の EMFU を示しています。

https://storage.googleapis.com/gweb-cloudblog-publish/images/0-Gemma-Model-Architecture.max-900x900.png
https://storage.googleapis.com/gweb-cloudblog-publish/images/1-EMFU-Gemma.max-1700x1700.png

Gemma-2b と 7b におけるモデルの実効 FLOP 使用率。TPU v5e-256 と v5p-128 で MaxText を使用して測定。コンテキストの長さは 8,192。2024 年 2 月時点。

これらの結果は、MaxText のリファレンス実装を使用して導き出されました。また、Hugging Face Transformers を使用した、Gemma モデルのトレーニングとファインチューニングのための実装も提供しています。

MaxText を使用した高パフォーマンス トレーニングの実現

Google は、モデル アーキテクチャの違い、コンテキストの長さなどのトレーニング パラメータの違い、基礎となるクラスタのスケールの違いがあることから、モデルタイプ間でトレーニング インフラストラクチャのパフォーマンスを比較することが困難であることを認識しています。Llama 2 に関して公開された結果(合計トークン数と GPU 時間)を、Gemma 7B のトレーニングとの比較のためのベースラインとして選択した理由は次のとおりです。

  1. Gemma 7B とのモデル アーキテクチャの類似性

  2. Gemma 7B はコンテキストの長さを 2 倍にしてトレーニングされていることから、この比較では Llama 2 ベースラインが有利になる
https://storage.googleapis.com/gweb-cloudblog-publish/images/2-Gemma-7B-Training.max-1700x1700.png

Gemma-7b とベースラインの 1 ドルあたりの相対的なトレーニング パフォーマンス。TPU v5e-256 と v5p-128 で Gemma 7B(MaxText)を使用して測定。コンテキストの長さは 8,192。ベースライン(LLama2-7b)パフォーマンスは、公開されている結果に従い、合計 GPU 時間と合計トレーニング トークン数を使用して導き出されました。1 ドルあたりのパフォーマンスは、それぞれのアクセラレータの正規価格を使用して算出されています。2024 年 2 月時点。

1 ドルあたりのパフォーマンスは、(peak-flops*EMFU/VM インスタンスの正規価格)を使用して算出しました。MaxText のリファレンス実装を使用した場合、Gemma 7B モデルでは、ベースラインのトレーニング パフォーマンス(Llama2 7B)に対して、1 ドルあたりのパフォーマンスが最大 3 倍向上しました。ここで示されたパフォーマンスまたは 1 ドルあたりのパフォーマンスの違いは、モデル アーキテクチャ、ハイパーパラメータ、基礎となるアクセラレータ、トレーニング ソフトウェアの作用です。パフォーマンスの向上は、これらの要因のいずれか一つだけに起因するものではありません。

Gemma 推論パフォーマンス

LLM 推論は、多くの場合メモリ基準ですが、トレーニングは大規模な並列処理のメリットを得ることができます。推論は、プレフィルとデコードという、それぞれ異なるコンピューティング特性を持つ 2 つのフェーズで構成されます。プレフィル フェーズは、(トークン数 > ピーク FLOP / HBM 帯域幅の場合)計算依存型体制で動作できますが、デコード フェーズは自己回帰的であり、効率的にバッチ処理されない限り、メモリ依存型になる傾向があります。デコード フェーズでは一度に 1 つのトークンを処理するため、メモリ依存の領域を回避するためのバッチサイズが大きくなる傾向があります。したがって、(プレフィルとデコードの両方で)バッチサイズ全体を単純に大きくすることは最適ではない可能性があります。スループットとレイテンシ、そしてプレフィル長とデコード長の相互作用のため、ここでは入力(プレフィル)と出力(デコード)を別々に扱い、以下では出力トークンに焦点を当てます。

次に、観察結果を説明するために、1 ドルあたりのスループットを指標として使用します。1 ドルあたりのスループットは、モデルサーバーがユーザーからのすべてのリクエストにわたって生成できる、1 秒あたりの出力トークンの数を表すからです。これはグラフの Y 軸で、100 万単位の出力トークン数で測定されます。さらにこの数字を、特定リージョンで Cloud TPU v5e を使用した場合の Compute Engine CUD 料金で割ります。

Cloud TPU v5e JetStream を使用した TPU 推論のコスト効率の向上

スループット、費用、レイテンシは、モデルのサイズ、アクセラレータのタイプ、モデル アーキテクチャの種類、使用される精度の形式など、多くの要因に影響を受ける可能性があるため、推論パフォーマンスの測定が困難です。そのため、ベースライン TPU 推論スタックと比較した JetStream のパフォーマンスを測定する指標として、コスト効率(100 万トークンあたりのコスト)を使用しました。ベースライン TPU 推論スタックと比較して、TPU 推論向けに最適化された JetStream スタックでは、下のグラフに示すように(低いほど良い)、コスト効率が最大 3 倍向上することが確認されました。

https://storage.googleapis.com/gweb-cloudblog-publish/images/3-Gemma_7B_TPU_Inference_Performance_Relat.max-2000x2000.png

ベースライン TPU 推論スタックと比較した、100 万トークンあたりの JetStream の費用。Google 内部データ。TPU v5e-8 で Gemma 7B(MaxText)を使用して測定。特定のリクエスト率とバッチサイズで、入力の長さ 1,024、出力の長さ 1,024。連続的なバッチ処理、重みとアクティベーションの int8 量子化、KV キャッシュ。2024 年 4 月時点。

JetStream TPU 推論による Gemma 7B の大規模なサービング、1 ドルあたりの高スループット

また、JetStream スタックを使った Gemma 7B の大規模なサービングのパフォーマンスを観察し、ベースライン TPU 推論スタックと比較しました。このテストの一環として、これらの TPU 推論スタックに送信されるリクエスト率を 1 秒あたりリクエスト数 1256 の間で変化させ、可変長の入出力トークンで Gemma 7B のサービングを行った場合の 1 ドルあたりのスループットを測定しました。JetStream Gemma 7B のサービングを行った場合の 1 ドルあたりのスループットは、リクエスト率が最も高い場合でも、ベースラインより高いという一貫した動作が観察されました。

https://storage.googleapis.com/gweb-cloudblog-publish/images/4-Gemma_7B_TPU_Inference_Relative_Throughp.max-2100x2100.png

ベースライン TPU 推論スタックと比較した JetStream の 1 ドルあたりのスループット(1 ドルあたり 100 万トークン)。Google 内部データ。TPU v5e-8 で Gemma 7B(MaxText)を使用して測定。1~256 の異なるリクエスト率で、入力の長さ 1,024、出力の長さ 1,024。連続的なバッチ処理、重みとアクティベーションの int8 量子化、KV キャッシュ。2024 年 4 月時点。

1 ドルあたりのスループットと 100 万トークンあたりのコストを測定する

Google Kubernetes EngineGKE)の JetStream コンテナを使用して、テストをオーケストレーションしました。入力データセットには可変長の入力と出力が含まれているため、実世界の言語モデルの入力トラフィックを模倣しています。グラフを生成するために、JetStream Gemma モデルをデプロイし、モデルのエンドポイントへの 1 秒あたりのリクエストを徐々に増やしていきました。リクエスト率を上げると、初めはバッチサイズが大きくなり、スループットが向上し、トークンごとのレイテンシも増加するようになります。しかしバッチサイズが限界に達すると、それ以上のリクエストはキューに入れられ、生成される出力トークン数という点で、スループットはプラトー状態になります。

上に示したベンチマークは、プロンプトの長さ分布、サンプリング、バッチ処理の最適化の影響を受けやすいため、高パフォーマンス アテンション カーネルのバリエーションやその他の適応によってさらに改善することが可能です。ベンチマークを試すには、GKE での AI のベンチマーク フレームワークを使用して、GKE AI ワークロードの自動ベンチマークを実行します。

Google Cloud での高パフォーマンス LLM 推論

LLM の大規模でコスト効率の高いサービングを実現するために、Google Cloud は、オーケストレーション、フレームワーク、サービング レイヤ、アクセラレータ設定に応じてユーザーが採用できる、幅広いオプションを提供しています。これらのオプションには、大規模モデル推論のために Cloud TPU GPU の両方をサポートするオーケストレーション レイヤとしての GKE が含まれます。さらに各アクセラレータは、JetStreamJAXPyTorchMaxText)、Hugging Face TGITensorRT-LLMvLLM など、さまざまなサービング レイヤのオプションを提供しています。

アクセラレータ

フレームワーク

オーケストレーション

Google Cloud での AI 最適化推論スタック

Cloud GPU

PyTorch

GKE

vLLMHugging Face TGI

Triton + TensorRT-LLM

Cloud TPU

PyTorchJAXMaxText

GKE

JetStreamHugging Face TGI

まとめ

フレームワークとして JAX PyTorch のどちらを使用するか、またセルフマネージドの柔軟性を提供する GKE オーケストレーションとフルマネージドの統合 AI プラットフォーム(Vertex AI)のどちらを選ぶかにかかわらず、Google Cloud では、Cloud GPU または TPU を通じて Gemma を本番環境で簡単かつ大規模に実行できる、AI 最適化インフラストラクチャを利用できます。Google Cloud は、Gemma モデル、またはその他のオープンソースやカスタムの大規模言語モデルのために、高パフォーマンスでコスト効率の高いトレーニング、ファインチューニング、サービングのオプションを包括的に提供しています。

Cloud TPU v5e v5p を使用した Gemma モデルのトレーニング パフォーマンスでは、MaxText を使ったトレーニングに Gemma のリファレンス実装を使用することで、ベースラインと比較して 1 ドルあたりのパフォーマンスが最大 3 倍向上することが確認されました。また、Cloud TPU での推論に JetStream を使用することで、ベースラインと比較して推論効率が最大 3 倍向上することも確認されました。推論を実行するのが Cloud GPU であっても TPU であっても、Gemma モデル用に高度に最適化されたサービングの実装が用意されています。

開始する方法については、Gemma ドキュメントをご覧ください。Gemma モデルの概要、モデルへのアクセス、Gemma のオンデバイス バリアント、その他あらゆるリソースを確認できます。モデル、アーキテクチャ、評価、安全ベンチマークについて詳しくは、Gemma の技術報告書をご覧ください。また GKE での Gemma に関するドキュメントでは、試用を開始する方法がわかりやすく説明されています。皆様が Gemma を使用して何を作成されるかを楽しみにしております。

ー プロダクト マネージャー Vilobh Meshram

ー シニア プロダクト マネージャー Vaibhav Singh

投稿先