bfloat16 の数値形式

精度の低い浮動小数点を使用して、精度を失わずに収束までの時間を短縮する方法が一般的です。TPU は、マトリックス演算を実行するときに bfloat16 数値形式を使用します。行列乗算演算は bfloat16 値に対して実行され、累積は IEEE float32 値に対して実行されます。

bfloat16 は、1 つの符号ビット、8 つの指数ビット、7 つの仮数ビットで構成されている機械学習用のカスタム 16 ビット浮動小数点形式です。  次の図は、float32: IEEE の単精度、float16: IEEE の半精度、bfloat16 の 3 つの浮動小数点形式の内部を示しています。

画像

bfloat16 のダイナミック レンジは float32 と同じであり、メモリ量の半分を消費します。研究によると、bfloat16 形式は float32 形式と同様にコンバージして、パフォーマンスの向上とメモリ使用量の削減が実現されます。

bfloat16 の選択

Google のハードウェア チームは、float32 からの移行コストを最小限に抑えながら、ディープ ラーニング モデルを正確にトレーニングする能力を維持しながら、ハードウェアの効率を向上させるために Cloud TPU 用に bfloat16 を選択しました。ハードウェア乗数の物理サイズは、仮数幅の 2 乗にスケーリングされます。仮数ビットが FP16 よりも少ない場合、bfloat16 乗算器はシリコンの一般的な FP16 乗算器の約半分のサイズであり、float32 乗算器の 8 分の 1 の大きさになります。

ニューラル ネットワークは、仮数よりも指数のサイズにかなり敏感です。アンダーフロー、オーバーフロー、NaN で同じ動作になるように、bfloat16 の指数サイズは float32 と同じです。bfloat16 は、float32 とは異なる方法で非正規化を処理し、ゼロにフラッシュします。損失スケーリングのような特別な処理が必要な float16 とは異なり、bfloat16 はディープ ニューラル ネットワークをトレーニングして実行する際の float32 の一時的な代替となるものです。

混合精度トレーニング

ディープ ニューラル ネットワーク内のほとんどの計算では、たとえば各数値の 18 桁目を計算する必要はありません。ネットワークは精度の低い近似値を使用して、同じ精度でタスクを達成できます。一部のモデルでは、低い精度で高い精度に到達することもできます。

Cloud TPU をプログラミングすると、TPU ランタイムによって自動形式変換が行われます。 値は、XLA コンパイラによって float32bfloat16 の間でシームレスに変換されます。これにより、デフォルトで float32 形式を使用してモデルを作成し、コードを変更せずに、パフォーマンス上の利点を得ることができます。

モデルのポータビリティ

TPU ハードウェアはこれらの値を bfloat16 に自動的にキャストできるため、モデル内のパラメータの値と有効化は完全な 32 ビット形式で保存できます。Cloud TPU でトレーニングされたモデルから取得したチェックポイントは、大規模な手動変換なしで、他のハードウェア プラットフォームにデプロイできます(CPU や GPU での推論や微調整など)。

bfloat16 によるパフォーマンスの改善

TPU での自動形式変換により、数値の精度について考慮する必要がなくなりますが、値を明示的に bfloat16 にキャストすることで、さらにパフォーマンスを向上させることができます。値が明示的に bfloat16 にキャストされる理由は 2 つあります。

  1. 値を bfloat16 形式で格納すると、オンチップ メモリを節約し、Cloud TPU でより大きなモデルのトレーニングや、より大きなバッチサイズの使用が可能になります。

  2. 一部の演算にはメモリ帯域幅の制限があります。つまり、メモリからデータを読み込むのにかかる時間が、計算の実行に費やす全体的な時間を遅くする可能性があります。これらの演算のオペランドと出力を bfloat16 形式で格納すると、転送する必要があるデータ量が減り、全体的な速度が向上します。

開始するには、Cloud TPU 用に最適化された bfloat16 対応のリファレンス モデルのいずれかでハンズオン経験を行うことをおすすめします。Google のパフォーマンス ガイド、プロファイリング ツールガイドトラブルシューティング ガイドに、マシンの作成と最適化に役立つ詳細な技術情報が記載されています。ご自身で機械学習モデルを作成して最適化してください。