コンテンツに移動
デベロッパー

最新の生成 AI に確率的丸めが不可欠な理由

2026年1月8日
https://storage.googleapis.com/gweb-cloudblog-publish/images/karl_december25_hero.max-1500x1500.jpg
Karl Weinmeister

Director, Developer Relations

※この投稿は米国時間 2025 年 12 月 20 日に、Google Cloud blog に投稿されたものの抄訳です。

1940 年代のコンピュータ草創期に、数学者たちは丸め誤差の動作に関する仮定に欠陥があることに気づきました。固定小数点演算では、誤差が相殺されずに累積され、計算の精度が低下するのです。数年後、「ランダム丸め」が提案されました。これは、余りに比例するランダムな確率に基づいて切り上げまたは切り下げを行う手法です。

生成 AI の時代を迎えた現在、私たちは数値演算の新たな課題に直面しています。メモリのボトルネックを克服するため、業界は FP8 や新たな 4 ビット標準などの低精度形式に移行しつつあります。しかし、低精度のトレーニングは不安定です。標準の丸めでは、学習を推進する僅かな勾配の更新が破棄されるため、モデルのトレーニングが停滞します。1950 年代からあるこの手法は、現在では確率的丸めとして知られており、シグナルを失わずに大規模なモデルをトレーニングすることを可能にします。この記事では、JAX や Qwix などのフレームワークが最新の Google Cloud ハードウェアにこの手法を適用して、低精度のトレーニングを可能にする仕組みについて説明します。

勾配消失が発生するタイミング

低精度トレーニングの課題は、更新の消失です。これは、小さな勾配の更新が最近接丸め(RTN)演算によって規則正しくゼロに丸められるときに発生します。たとえば、大きな重みが 100.0 で学習の更新が 0.001 の場合、低精度形式では 100.001100.0 と同じものとして登録される可能性があります。更新が事実上消失するため、学習が停滞します。

数値をスイミング プールに例えてみましょう。このプールでは、ガロン単位の整数で水量が記録されます。小さじ 1 杯の水を加えると、システムは合計水量を計算して最も近い整数に切り下げます。これにより、追加分は事実上削除されます。小さじ 1 杯の水を 10 億回注いでも、記録される水位はまったく上昇しません。

確率による精度

確率的丸め(SR)は、決定論的丸めルールを確率に置き換えることで、この問題を解決します。たとえば、1.4 を常に 1 に切り下げるのではなく、確率 60% で 1 に切り下げ、確率 40% で 2 に切り上げます。

これは、数学的には区間 [⌊x⌋,⌊x⌋+1] の値 x に対して次のように定義されます。

https://storage.googleapis.com/gweb-cloudblog-publish/images/equation.max-600x600.png

定義プロパティは、SR の期待値に偏りがないことです。

  • 確率的丸め: E[SR(x)] = x

  • 最近接丸め: E[RTN(x)] ≠ x

違いを確認するため、1.4 の例をもう一度見てみましょう。RTN は決定論的であり、毎回 1 を出力します。分散は 0 です。安定はしているものの、出力は常に誤っています。これに対して、SR は 1、1、2、1、2...のようなノイズの多いストリームを出力します。平均値(1.4)は正しいものの、個々の値は変動します。

分散の式を使用して、偏りがゼロであることの「コスト」を定量化できます。

Var(SR(x))=p(1-p) where p=x-⌊x⌋

これに対して RTN は、分散はゼロですが、誤差が急速に蓄積するという問題があります。N 回の演算の合計では、定誤差が直線的に増加する可能性があります(O(N))。わずかな数値であっても常に切り捨て続けると、急速に誤差が蓄積されます。

SR の動作はこれとは異なります。誤差はランダムで偏りがないため、相殺される傾向があります。こうした「ランダム ウォーク」では、誤差の合計は演算回数の平方根に比例して増加します(O(√N))。

確率的丸めではノイズが発生しますが、トレードオフは多くの場合無害です。ディープ ラーニングでは、この追加された分散はドロップアウトや正規化と同様に一種の暗黙的な正則化として機能することが多く、モデルが浅い局所的な最小値を回避して、より適切に一般化するのに役立ちます。

Google Cloud での実装

Google Cloud は、Cloud TPUNVIDIA Blackwell GPU など、最新世代の AI アクセラレータを通じて確率的丸めをサポートしています。これらのアクセラレータは、AI 向けに最適化された Google Kubernetes Engine クラスタでも使用できます。

TPU でのネイティブ サポート

Google の TPU アーキテクチャでは、Matrix Multiply Unit(MXU)における確率的丸めのネイティブ ハードウェア サポートが提供されます。これにより、モデルのパフォーマンスを大幅に低下させることなく、INT4、INT8、FP8 などの低精度形式でトレーニングを実行できます。

Google の Qwix ライブラリも使用できます。これは、トレーニング(QAT)とトレーニング後の量子化(PTQ)の両方をサポートする JAX 向け量子化ツールキットです。INT8 でモデルを量子化するようにこのライブラリを構成し、バックワード パスに対する確率的丸めを明示的に有効化することで勾配消失を防ぐ方法を以下に示します。

lang-py

lang-py
読み込んでいます...

Qwix は下位レベルのハードウェア命令の複雑さを抽象化します。これにより、シンプルな構成で量子化ロジックをモデルのグラフに直接挿入できます。

NVIDIA Blackwell および A4X VM

Google Cloud で NVIDIA GPU を使用する場合も同様です。NVIDIA GB200 NVL72 システムを搭載した業界初のクラウド インスタンスである A4X VM をデプロイできます。この VM は、72 個の Blackwell GPU を、単一のスーパーコンピューティング ユニットである AI Hypercomputer に接続します。

Blackwell は、NVFP4(ブロック スケーリング戦略を利用する 4 ビット浮動小数点形式)に対するネイティブ ハードウェア サポートを提供します。精度を維持するため、NVFP4BlockScaling レシピにより、偏りを回避する確率的丸めが勾配に自動的に適用され、さらに他の高度なスケーリング手法も適用されます。

このレシピを使用してレイヤを te.autocast でラップした場合、ライブラリはバックワード パスに対して以下のモードを使用します。

lang-py

lang-py
読み込んでいます...

このコンテキスト マネージャーを入力するだけで、A4X の GB200 GPU は 4 ビット精度で行列乗算を実行するとともに、バックワード パスに確率的丸めを使用します。これにより、収束に悪影響を与えることなく、以前の世代の最大 4 倍のトレーニング パフォーマンスを実現できます。

本番環境のベスト プラクティス

本番環境で SR を効果的に実装するには、まず、確率的丸めがトレーニング専用の手法であることに留意してください。SR は非決定論的であるため、一貫性のある出力が必要とされる推論ワークロードでは、標準の最近接丸めを使用する必要があります。

次に、SR は発散をデバッグするためのツールとして使用してください。低精度のトレーニングが安定しない場合は、勾配ノルムを確認してください。勾配消失の場合は SR の有効化で解決する可能性がありますが、勾配爆発の場合は別の問題が考えられます。

最後に、再現性を慎重に管理してください。SR は乱数の生成に依存するので、ビット単位の再現性の管理は困難です。jax.random.key(0) などを使用して常にグローバルなランダムシードを設定し、トレーニング実行で「決定論的なランダム性」が現れるようにします。これにより、内部の確率的操作にかかわらず、毎回同じ結果が生成されます。

確率的丸めは、低精度の演算のノイズを学習のシグナルに変換します。この 1950 年代からある数値処理手法は、A4X VMIronwood TPU で限界を押し広げている場合でも、次世代の AI パフォーマンスの可能性を引き出す鍵となります。

LinkedInXBluesky でつながって、AI インフラストラクチャの過去、現在、未来についてもっと話し合いましょう。

- デベロッパー リレーションズ担当ディレクター、Karl Weinmeister

投稿先