PyTorch XLA ワークロードをプロファイリングする

パフォーマンスの最適化は、効率的な ML モデルを構築するうえで重要な要素です。XProf プロファイリング ツールを使用すると、ML ワークロードのパフォーマンスを測定できます。XProf では、XLA デバイスでのモデルの実行の詳細なトレースをキャプチャできます。これらのトレースは、パフォーマンスのボトルネックの特定、デバイスの使用率の把握、コードの最適化に役立ちます。

このガイドでは、PyTorch XLA スクリプトからトレースをプログラムでキャプチャし、XProf と Tensorboard を使用して可視化するプロセスについて説明します。

トレースをキャプチャする

既存のトレーニング スクリプトに数行のコードを追加することで、トレースをキャプチャできます。トレースをキャプチャするための主なツールは torch_xla.debug.profiler モジュールです。通常、このモジュールはエイリアス xp でインポートされます。

1. プロファイラ サーバーを起動する

トレースをキャプチャする前に、プロファイラ サーバーを起動する必要があります。このサーバーはスクリプトのバックグラウンドで実行され、トレースデータを収集します。メインの実行ブロックの先頭付近で xp.start_server() を呼び出すことで起動できます。

2. トレース期間を定義する

プロファイリングするコードを xp.start_trace() 呼び出しと xp.stop_trace() 呼び出しでラップします。start_trace 関数は、トレース ファイルが保存されるディレクトリのパスを受け取ります。

関連性の最も高いオペレーションをキャプチャするため、メインのトレーニング ループをラップするのが一般的です。

# The directory where the trace files are stored.
log_dir = '/root/logs/'

# Start tracing
xp.start_trace(log_dir)

# ... your training loop or other code to be profiled ...
train_mnist()

# Stop tracing
xp.stop_trace()

3. カスタム トレースラベルを追加する

デフォルトでは、キャプチャされたトレースは低レベルの Pytorch XLA 関数であり、ナビゲートが難しい場合があります。xp.Trace() コンテキスト マネージャーを使用すると、コードの特定のセクションにカスタムラベルを追加できます。これらのラベルは、プロファイラのタイムライン ビューに名前付きブロックとして表示されるため、データ準備、フォワードパス、オプティマイザー ステップなど、特定のオペレーションを簡単に識別できます。

次の例は、トレーニング ステップのさまざまな部分にコンテキストを追加する方法を示しています。

def forward(self, x):
    # This entire block will be labeled 'forward' in the trace
    with xp.Trace('forward'):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 7*7*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# You can also nest context managers for more granular detail
for batch_idx, (data, target) in enumerate(train_loader):
    with torch_xla.step():
        with xp.Trace('train_step_data_prep_and_forward'):
            optimizer.zero_grad()
            data, target = data.to(device), target.to(device)
            output = model(data)

        with xp.Trace('train_step_loss_and_backward'):
            loss = loss_fn(output, target)
            loss.backward()

        with xp.Trace('train_step_optimizer_step_host'):
            optimizer.step()

実証済みの例

次の例は、mnist_xla.py ファイルに基づいて、PyTorch XLA スクリプトからトレースをキャプチャする方法を示しています。

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# PyTorch/XLA specific imports
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

def train_mnist():
    # ... (model definition and data loading code) ...
    print("Starting training...")
    # ... (training loop as defined in the previous section) ...
    print("Training finished!")

if __name__ == '__main__':
    # 1. Start the profiler server
    server = xp.start_server(9012)

    # 2. Start capturing the trace and define the output directory
    xp.start_trace('/root/logs/')

    # Run the training function that contains custom trace labels
    train_mnist()

    # 3. Stop the trace
    xp.stop_trace()

トレースを可視化する

スクリプトが終了すると、トレース ファイルは指定したディレクトリ(/root/logs/ など)に保存されます。このトレースは、XProf と TensorBoard を使用して可視化できます。

  1. TensorBoard をインストールします。

    pip install tensorboard_plugin_profile tensorboard
  2. TensorBoard を起動します。xp.start_trace() で使用したログ ディレクトリを TensorBoard で指定します。

    tensorboard --logdir /root/logs/
  3. プロファイルを表示します。TensorBoard から提供された URL(通常は http://localhost:6006)をウェブブラウザで開きます。[PyTorch XLA - Profile] タブに移動して、インタラクティブ トレースを表示します。作成したカスタムラベルを確認し、モデルのさまざまな部分の実行時間を分析できます。

Google Cloud でワークロードを実行する場合は、cloud-diagnostics-xprof ツールをおすすめします。これは、Tensorboard と XProf を実行する VM を使用してプロファイルの収集と表示を効率的に行います。