JAX を使用して Cloud TPU VM で計算を実行する

このドキュメントでは、JAX と Cloud TPU の使用について簡単に説明します。

このクイックスタートを使用する前に、Google Cloud Platform アカウントを作成し、Google Cloud CLI をインストールして gcloud コマンドを構成する必要があります。詳細しくは、アカウントと Cloud TPU プロジェクトを設定するをご覧ください。

Google Cloud CLI をインストールする

Google Cloud CLI には、Google Cloud プロダクトとサービスを操作するためのツールとライブラリが含まれています。詳細については、Google Cloud CLI のインストールをご覧ください。

gcloud コマンドを設定する

次のコマンドを実行して、Google Cloud プロジェクトを使用するように gcloud を構成し、TPU VM プレビュー版の必要なコンポーネントをインストールします。

  $ gcloud config set account your-email-account
  $ gcloud config set project your-project-id

Cloud TPU API を有効にする

  1. Cloud Shell で、次の gcloud コマンドを使用して Cloud TPU API を有効にします(Google Cloud Console から有効にすることもできます)。

    $ gcloud services enable tpu.googleapis.com
  2. 次のコマンドを実行して、サービス ID を作成します。

    $ gcloud beta services identity create --service tpu.googleapis.com

gcloud を使用した Cloud TPU VM の作成

Cloud TPU VM では、モデルとコードは TPU ホストマシン上で直接実行されます。TPU ホストには、直接 SSH で接続します。TPU ホストでは、任意のコードの実行、パッケージのインストール、ログの表示、コードのデバッグを直接行えす。

  1. TPU VM は、Google Cloud Shell か、Google Cloud CLI がインストールされているコンピュータ ターミナルから、次のコマンドを実行して作成します。

    (vm)$ gcloud compute tpus tpu-vm create tpu-name \
    --zone=us-central2-b \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base

    必須項目

    zone
    Cloud TPU を作成するゾーン
    accelerator-type
    アクセラレータ タイプでは、作成する Cloud TPU のバージョンとサイズを指定します。TPU のバージョンごとにサポートされているアクセラレータ タイプの詳細については、TPU のバージョンをご覧ください。
    version
    Cloud TPU ソフトウェアのバージョン。すべての TPU タイプで tpu-ubuntu2204-base が使用されます。

Cloud TPU VM に接続する

次のコマンドを使用して、TPU VM に SSH 接続します。

$ gcloud compute tpus tpu-vm ssh tpu-name --zone=us-central2-b

必須フィールド

tpu_name
接続する TPU VM の名前。
zone
Cloud TPU を作成したゾーン

Cloud TPU VM に JAX をインストールする

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

システム チェック

JAX が TPU にアクセスし、基本オペレーションを実行できることを確認します。

Python 3 インタプリタを起動する

(vm)$ python3
>>> import jax

使用可能な TPU コアの数を表示する

>>> jax.device_count()

TPU コアの数が表示されます。v4 TPU を使用している場合は、4 にします。v2 または v3 TPU を使用している場合は、8 にします。

簡単な計算を実行する

>>> jax.numpy.add(1, 1)

numpy の足し算の結果が表示されます。

コマンドからの出力は次のようになります。

Array(2, dtype=int32, weak_type=true)

Python インタプリタを終了する

>>> exit()

TPU VM で JAX コードを実行する

これで、任意の JAX コードを実行できるようになりました。Flax の例は、JAX で標準の ML モデルの実行を開始するのに適しています。たとえば、基本的な MNIST 畳み込みネットワークをトレーニングする場合は、以下に従います。

  1. Flax サンプルの依存関係をインストールする

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. FLAX のインストール

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Flax MNIST トレーニング スクリプトを実行します。

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5

このスクリプトはデータセットをダウンロードして、トレーニングを開始します。スクリプトの出力は、次のようになります。

  0214 18:00:50.660087 140369022753856 train.py:146] epoch:  1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
  I0214 18:00:52.015867 140369022753856 train.py:146] epoch:  2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
  I0214 18:00:53.377511 140369022753856 train.py:146] epoch:  3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
  I0214 18:00:54.727168 140369022753856 train.py:146] epoch:  4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
  I0214 18:00:56.082807 140369022753856 train.py:146] epoch:  5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

クリーンアップ

TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。

  1. Compute Engine インスタンスとの接続を切断していない場合は切断します。

    (vm)$ exit
  2. Cloud TPU を削除します。

    $ gcloud compute tpus tpu-vm delete tpu-name \
      --zone=us-central2-b
  3. 次のコマンドを実行して、リソースが削除されたことを確認します。TPU がリストに表示されないことを確認します。削除には数分かかることがあります。

    $ gcloud compute tpus tpu-vm list \
      --zone=us-central2-b

パフォーマンスに関する注意

ここでは、特に JAX での TPU の使用に関連する重要事項をいくつか説明します。

パディング

TPU のパフォーマンスが低下する最も一般的な原因の 1 つは、間違ったパディングが導入されることです。

  • Cloud TPU 内の配列はタイル状になっています。これは、ディメンションの 1 つを 8 の倍数にパディングし、別のディメンションを 128 の倍数にパディングする必要があります。
  • 行列乗算ユニットは、パディングの必要性を最小限に抑える大規模な行列のペアで最もパフォーマンスを発揮します。

bfloat16 dtype

デフォルトでは、TPU 上の JAX の行列乗算は、float32 累積と bfloat16 を使用します。これは、関連する jax.numpy 関数の呼び出し(matmul、dot、einsum など)の precision 引数で制御できます。具体的には次のようにします。

  • precision=jax.lax.Precision.DEFAULT: bfloat16 の混合精度を使用する(最速)
  • precision=jax.lax.Precision.HIGH: 複数の MXU パスを使用して精度を高める
  • precision=jax.lax.Precision.HIGHEST: さらに多くの MXU パスを使用して、フル精度の float32 を実現する

また、JAX では bfloat16 dtype も追加されます。これは、配列を明示的に bfloat16 にキャストして使用できます(例: jax.numpy.array(x, dtype=jax.numpy.bfloat16))。

Colab で JAX を実行する

Colab ノートブックで JAX コードを実行すると、Colab によって以前の TPU ノードが自動的に作成されます。TPU ノードには異なるアーキテクチャがあります。詳細については、システム アーキテクチャをご覧ください。

次のステップ

Cloud TPU の詳細については、以下をご覧ください。