Pax を使用して単一ホストの TPU でトレーニングする


このドキュメントでは、単一ホストの TPU(v2-8、v3-8、v4-8)での Pax の使用について簡単に説明します。

Pax は、JAX 上で ML テストを構成して実行するためのフレームワークです。Pax は、インフラストラクチャ コンポーネントを既存の ML フレームワークと共有し、モジュール化に PraxisPraxis モデリング ライブラリを使用することで、大規模な ML を簡素化することに重点を置いています。

目標

  • トレーニング用の TPU リソースを設定する
  • 単一ホストの TPU に Pax をインストールする
  • Pax を使用して Transformer ベースの SPMD モデルをトレーニングする

準備

次のコマンドを実行して、Cloud TPU プロジェクトを使用するように gcloud を構成し、単一ホストの TPU で Pax を実行するモデルをトレーニングするために必要なコンポーネントをインストールします。

Google Cloud CLI をインストールする

Google Cloud CLI には、Google Cloud CLI のプロダクトとサービスを操作するためのツールとライブラリが含まれています。まだインストールしていない場合は、Google Cloud CLI のインストールの手順に沿ってインストールします。

gcloud コマンドを設定する

gcloud auth list を実行して、使用可能なアカウントを確認します)。

$ gcloud config set account account

$ gcloud config set project project-id

Cloud TPU API を有効にする

Cloud Shell で、次の gcloud コマンドを使用して Cloud TPU API を有効にします(Google Cloud コンソールから有効にすることもできます)。

$ gcloud services enable tpu.googleapis.com

次のコマンドを実行して、サービス ID(サービス アカウント)を作成します。

$ gcloud beta services identity create --service tpu.googleapis.com

TPU VM を作成する

Cloud TPU VM では、モデルとコードは TPU VM 上で直接実行されます。TPU VM には、直接 SSH で接続します。TPU VM では、任意のコードの実行、パッケージのインストール、ログの表示、コードのデバッグを直接行えす。

TPU VM は、Google Cloud Shell か、Google Cloud CLI がインストールされているコンピュータ ターミナルから、次のコマンドを実行して作成します。

契約での利用可否に基づいて zone を設定します。必要に応じて TPU のリージョンとゾーンをご覧ください。

accelerator-type 変数を v2-8、v3-8、または v4-8 に設定します。

version 変数を、v2 と v3 の TPU バージョンの場合は tpu-vm-base、v4 TPU の場合は tpu-vm-v4-base に設定します。

$ gcloud compute tpus tpu-vm create tpu-name \
--zone zone \
--accelerator-type accelerator-type \
--version version

Google Cloud TPU VM に接続する

次のコマンドを使用して、TPU VM に SSH 接続します。

$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone

VM にログインすると、シェル プロンプトが username@projectname から username@vm-name に変わります。

Google Cloud TPU VM に Pax をインストールする

次のコマンドを使用して、TPU VM に Pax、JAX、libtpu をインストールします。

(vm)$ python3 -m pip install -U pip \
python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

システム チェック

JAX から TPU コアが見えることを確認することで、すべてが正しくインストールされていることをテストします。

(vm)$ python3 -c "import jax; print(jax.device_count())"

TPU コアの数が表示されます。v2-8 または v3-8 を使用している場合は 8、v4-8 を使用している場合は 4 になるはずです。

TPU VM で Pax コードを実行する

これで、任意の Pax コードを実行できるようになりました。lm_cloudは、Pax でモデルの実行を開始するのに最適です。たとえば、次のコマンドは、合成データで 2B パラメータ Transformer ベースの SPMD 言語モデルをトレーニングします。

次のコマンドは、SPMD 言語モデルのトレーニング出力を示しています。300 ステップが約 20 分でトレーニングされます。

(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir

v4-8 スライスでは、出力には次のものが含まれます。

損失とステップ時間

ステップ = step_# loss = loss でのサマリー テンソル
ステップ = step_# のステップ/秒 x でのサマリー テンソル

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。

Compute Engine インスタンスとの接続を切断していない場合は切断します。

(vm)$ exit

Cloud TPU を削除します。

$ gcloud compute tpus tpu-vm delete tpu-name  --zone zone

次のステップ

Cloud TPU の詳細については、以下をご覧ください。