PyTorch / XLA 2.4 を発表: Pallas とデベロッパー エクスペリエンスの向上、そして「イーガーモード」
Bhavya Bahl
Software Engineer
Duncan Campbell
Developer Advocate, Google Cloud
※この投稿は米国時間 2024 年 7 月 31 日に、Google Cloud blog に投稿されたものの抄訳です。
ディープ ラーニングの研究者や実務担当者に対し、オープンソースの PyTorch ML ライブラリと XLA ML コンパイラは、柔軟性の高い強力なモデル トレーニング、ファインチューニング、サービングをもたらします。このたび、PyTorch / XLA チームは PyTorch / XLA 2.4 をリリースいたしました。このリリースでは、前回のリリースに、デベロッパーの課題に対処するための注目すべき改良がいくつか加えられています。ここでは、PyTorch / XLA をさらに使いやすくする最新機能を一部ご紹介します。
-
TPU と GPU の両方をサポートするカスタム カーネル言語 Pallas の改善
-
新しい API 呼び出し
-
試験運用版「イーガーモード」の導入
-
新しい TPU コマンドライン インターフェース
では、ユーザビリティの向上について見ていきましょう。
Pallas の機能強化
XLA コンパイラ自体も既存のモデルを最適化できますが、モデルの作成者がカスタム カーネル コードを書くことで、より高いパフォーマンスを得られる場合があります。カスタム カーネル言語の Pallas は TPU と GPU の両方に対応しているため、C++ のようなより複雑な低水準言語を使わなくても、ハードウェアに近い Python で、よりパフォーマンスの高いコードを記述できます。Pallas は Triton ライブラリと似ていますが、TPU と GPU の両方で動作するため、ある ML アクセラレータから別のアクセラレータにモデルを簡単に移植できます。
最新の PyTorch / XLA 2.4 リリースでは、Pallas の機能とユーザー エクスペリエンスを強化するアップデートが導入されています。
-
PyTorch の autograd(自動勾配計算)と完全に統合された FlashAttention のサポートを追加
-
推論のために PagedAttention を組み込みでサポート
-
グループ行列乗算のために MegaBlocks のブロック スパース カーネルを Autograd 関数としてサポートし、誤差逆伝播を手動で行う必要性を排除
API の変更点
PyTorch / XLA 2.4 では、既存の PyTorch ワークフローへのインテグレーションを容易にする新しい呼び出しがいくつか導入されています。以下に例を挙げます。
現在は上記のような呼び出しが可能になりましたが、従来は以下のように呼び出す必要がありました。
また、xm.mark_step() を呼び出す代わりに torch_xla.sync() を呼び出せるようになりました。こうした改善により、コードを PyTorch / XLA に変換しやすくなり、デベロッパー ワークフローが向上しました。API 呼び出しのその他の変更点については、リリースノートをご確認ください。
試験運用版のイーガーモード
PyTorch / XLA をしばらくお使いなら、モデルが「遅延実行」されていることはご存じでしょう。これは、PyTorch / XLA がオペレーションのコンピューティング グラフを作成してから、XLA デバイスのターゲット ハードウェアで実行するためにモデルを送信しているということです。新しいイーガーモードでは、オペレーションはコンパイルされ、ターゲット ハードウェア上ですぐに実行されます。
ただし、この機能には、TPU 自体に真のイーガーモードがないという問題点があります。デフォルトでは、個々の命令はすぐに TPU に送信されません。TPU では、個々の PyTorch オペレーションの後に「mark_step」呼び出しを追加してコンパイルと実行を強制することにより、この機能を実現しています。これによりイーガーモードの機能は実現していますが、これはネイティブの機能ではなくエミュレーションとして実現したものです。
このリリースのイーガーモードは、本番環境ではなくローカル環境で実行することを意図しています。イーガーモードを使用することで、多くの本番環境システムのように多数のデバイスにデプロイしなくても、お手元のマシンでローカルにモデルをデバッグすることが容易になると期待しています。
Cloud TPU 情報コマンドライン インターフェース
Nvidia の GPU を使ったことがある方は nvidia-smi ツールをよくご存じかもしれません。このツールを使って、GPU ワークロードのデバッグ、使われているコアの識別、特定のワークロードのメモリ消費量の確認を行うことができます。現在は Cloud TPU にも同様のコマンドライン ユーティリティがあり、tpu-info で使用状況やデバイス情報を簡単に表示できます。以下に出力の例を示します。
今すぐ PyTorch / XLA 2.4 の利用を開始しましょう
PyTorch / XLA 2.4 では API も一部変更されていますが、この最新バージョンの最良の点は、ご利用の既存コードとも互換性があり、新しい API 呼び出しによって今後の開発プロセスが容易になる点です。この機会に、最新バージョンをぜひお試しください。詳細については、プロジェクトの GitHub リポジトリをご覧ください。