PyTorch デベロッパー向け JAX 基礎ガイド
Anfal Siddiqui
Machine Learning Engineer, Cloud AI
※この投稿は米国時間 2025 年 1 月 7 日に、Google Cloud blog に投稿されたものの抄訳です。
PyTorch を使われている方なら、JAX のメリットについて耳にしたことがあると思います。たとえば、その高いパフォーマンス、エレガントな関数型プログラミング アプローチ、強力な組み込みの並列計算サポートなどです。しかし、それと同時に、JAX を使い始めるときに利用できる、わかりやすく進めやすいチュートリアルを見つけるのは難しいかもしれません。JAX の新しいコンセプトを、馴染みのある PyTorch の構成要素と結び付け、JAX の基礎への理解を促すようなものです。そこでこのたび、そのようなチュートリアルを作成しました。
このチュートリアルでは、PyTorch ユーザーの視点から JAX エコシステムの基礎を探っていきます。タイタニック号の沈没事故で生き残る乗客を予測するという代表的な ML タスクを取り上げ、簡単なニューラル ネットワークを両方のフレームワークでトレーニングします。その過程で、モデルの定義やインスタンス化、トレーニングなど、JAX の多くの要素と、PyTorch の同等の要素の対応を示しながら、JAX を紹介していきます。
付属のノートブックで、コード例の全文をご確認いただけます(https://www.kaggle.com/code/anfalatgoogle/pytorch-developer-s-guide-to-jax-fundamentals)。
JAX におけるモジュール性
PyTorch ユーザーから見て、高度にモジュール化された JAX のエコシステムは、使い慣れた PyTorch とはかなり違うと最初は感じるかもしれません。JAX は、自動微分をサポートする、高パフォーマンスの数値計算ライブラリを目指しています。PyTorch とは異なり、ニューラル ネットワーク、オプティマイザなどの定義のサポートを明示的に組み込もうとはせず、柔軟性を念頭に設計されているため、任意のフレームワークを機能に追加できるようになっています。
このチュートリアルでは、Flax ニューラル ネットワーク ライブラリと Optax 最適化ライブラリを使用します。いずれも一般的に使われ、サポートも充実したライブラリです。先に、新しい Flax NNX API でニューラル ネットワークをトレーニングする方法と、PyTorch との類似性を確認します。その後、同じことを、現在も広く使われている以前の Linen API で行う方法を示します。
関数型プログラミング
チュートリアルを始める前に、PyTorch やその他のフレームワークで使われているオブジェクト指向プログラミングではなく、関数型プログラミングが JAX で使われている理由を説明します。簡単に言うと、関数型プログラミングでは、状態を変更できず、副作用を生じさせない、つまり同じ入力に対して常に同じ出力を生成する、純粋関数を主に使います。JAX では、コンポーズ可能な関数と変更できない配列の多用という形でこれが現れています。
純粋関数と関数型プログラミングの予測しやすいという性質が、JAX の多くのメリットを可能にしています。たとえば、ジャストインタイム(JIT)コンパイルにより、XLA コンパイラは GPU や TPU のコードを大幅に最適化し、速度を大きく向上できます。また、JAX では処理のシャーディングや並列化が非常に簡単になっています。詳細については、JAX の公式チュートリアルをご覧ください。
関数型プログラミングに慣れていなくても心配ご無用です。これから見ていくように、Flax NNX ではこうしたことが標準の Python に似たイディオムの背後に隠されています。
データ読み込み
JAX におけるデータ読み込みは非常にシンプルで、PyTorch と同じように行います。PyTorch のデータセット / データローダを使用して、簡単な collate_fn
で、JAX のすべての計算を支える NumPy に似た配列への変換を行うことができます。
モデルの定義
Flax NNX API におけるニューラル ネットワークの定義は、PyTorch の場合とよく似ています。ここでは、簡単な 2 層の多層パーセプトロンを両方のフレームワークで定義します。まずは PyTorch からです。
NNX のモデル定義は上記の PyTorch コードとよく似ています。いずれの場合も __init__
を使用してモデルの層を定義します。__call__
は forward
に対応します。
モデルの初期化と使用
NNX でのモデルの初期化は PyTorch とほぼ同じです。いずれのフレームワークでも、モデルクラスのインスタンスをインスタンス化するときに、モデルのパラメータが(遅延初期化ではなく)早期初期化され、インスタンス自体と関連付けられます。NNX の唯一の違いは、モデルのインスタンス化のときに疑似乱数生成器(PRNG)キーを渡す必要があることです。NNX では、JAX の関数型という性質に沿って、暗黙的なグローバル ランダム状態が回避されるため、PRNG キーを明示的に渡す必要があります。これにより、PRNG の生成が容易に再現可能、並列化可能、ベクトル化可能になります。詳細については、JAX のドキュメントをご覧ください。
モデルを実際に使ったデータ群の処理は、2 つのフレームワークで同じです。
トレーニングのステップと誤差逆伝播
トレーニング ループに関しては、PyTorch と Flax NNX とで、いくつか重要な違いがあります。それを確認するために、NNX の完全なトレーニング ループを順を追って組み立ててみましょう。
設定
いずれのフレームワークでも、オプティマイザを作成し、独自の最適化アルゴリズムを柔軟に指定できます。PyTorch ではモデルのパラメータを渡す必要がありますが、Flax NNX ではモデルを直接渡すだけで、背後にある Optax オプティマイザですべてのインタラクションが処理されます。
フォワード + バックワード パス
PyTorch と JAX の最大の違いは、おそらく完全なフォワード / バックワード パスを行う方法です。PyTorch では loss.backward()
で勾配を計算します。このとき AutoGrad がトリガーされ、loss
の計算グラフに沿って勾配が計算されます。
これに対して、JAX の自動微分はより数学に近く、関数の勾配があります。具体的には、nnx.value_and_grad
/ nnx.grad
で関数 loss_fn
を受け取り、関数 grad_fn
を返します。その後、grad_fn
自体が、入力に応じて、loss_fn
の出力の勾配を返します。
この例では、PyTorch と同じことを loss_fn
で行っています。最初にフォワードパスから logits
を取得してから、おなじみの loss
を計算しています。その後 grad_fn
が、model
のパラメータに応じて、loss
の勾配を計算しています。数学的には、返される grads
は ∂J/∂θ
です。これは PyTorch の背後で行われている処理と同じです。PyTorch では loss.backward()
の実行時に勾配がテンソルの .grad
属性に格納されるのに対して、JAX / Flax NNX では、状態を変更しないという関数型アプローチに従い、勾配が直接返されています。
オプティマイザのステップ
PyTorch では、optimizer.step()
により、勾配を使用して重みがインプレースで更新されます。NNX でも重みがインプレースで更新されますが、バックワード パスで計算した grads を直接渡す必要があります。これは PyTorch と同じ最適化のステップですが、JAX の基礎にある関数型という性質に沿って、PyTorch よりもわずかに明示的になっています。
完全なトレーニング ループ
これで、JAX / Flax NNX で完全なトレーニング ループを構築するために必要なものがすべて揃いました。参考として、よく知っている PyTorch のループを先に確認しましょう。
次に、NNX の完全なトレーニング ループです。
最大のポイントは、トレーニング ループが PyTorch と JAX / Flax NNX とでよく似ているということです。相違点のほとんどは、オブジェクト指向プログラミングか関数型プログラミングかの違いに行き着きます。関数型プログラミングと、関数の勾配を考えるという点については、慣れるまでに少し時間がかかるかもしれませんが、慣れてしまえば JIT コンパイルや自動並列化といった前述の JAX のメリットを享受できます。たとえば、@nnx.jit
のアノテーションを上記の関数に追加するだけで、P100 GPU で 500 エポックのモデルのトレーニング時間を 6.25 分からわずか 1.8 分に短縮できることが Kaggle で確認されています。他の CPU や TPU、NVIDIA 以外の GPU でも同程度のスピードアップを見込めることがわかっています。
Flax Linen のリファレンス
前述のとおり、JAX のエコシステムは非常に柔軟であり、任意のフレームワークを使用できるようになっています。新規のユーザーには NNX がおすすめのソリューションですが、Flax Linen API も、MaxText や MaxDiffusion などの強力なフレームワークで現在も広く使用されています。NNX はより Python に似て、状態管理の複雑さが隠されていますが、Linen は純粋関数型プログラミングにより忠実になっています。
JAX のエコシステムに加わる場合、この両方に慣れておくことが大きなメリットになります。その一助となるよう、以下では NNX のほとんどのコードを Linen で再現し、主な違いをコメントとして入れます。
次のステップ
このブログ投稿で得た JAX / Flax の知識により、独自のニューラル ネットワークを記述する準備ができたと思います。Google Colab か Kaggle を使用すれば、すぐに始められます。Kaggle で課題を探し、Flax NNX を使ってモデルを新規に記述するか、MaxText を使って大規模言語モデル(LLM)のトレーニングを開始するなど、可能性は無限にあります。
今回取り上げたのは、JAX と Flax の一部にすぎません。JIT、自動ベクトル化、カスタム勾配などの詳細については、JAX と Flax のドキュメントをご覧ください。
-Cloud AI、ML エンジニア Anfal Siddiqui