このドキュメントでは、単一ホストの TPU(v2-8、v3-8、v4-8)での Pax の使用について簡単に説明します。
Pax は、JAX 上で ML テストを構成して実行するためのフレームワークです。Pax は、インフラストラクチャ コンポーネントを既存の ML フレームワークと共有し、モジュール化に PraxisPraxis モデリング ライブラリを使用することで、大規模な ML を簡素化することに重点を置いています。
目標
- トレーニング用の TPU リソースを設定する
- 単一ホストの TPU に Pax をインストールする
- Pax を使用して Transformer ベースの SPMD モデルをトレーニングする
準備
次のコマンドを実行して、Cloud TPU プロジェクトを使用するように gcloud
を構成し、単一ホストの TPU で Pax を実行するモデルをトレーニングするために必要なコンポーネントをインストールします。
Google Cloud CLI をインストールする
Google Cloud CLI には、Google Cloud CLI のプロダクトとサービスを操作するためのツールとライブラリが含まれています。まだインストールしていない場合は、Google Cloud CLI のインストールの手順に沿ってインストールします。
gcloud
コマンドを設定する
(gcloud auth list
を実行して、使用可能なアカウントを確認します)。
$ gcloud config set account account
$ gcloud config set project project-id
Cloud TPU API を有効にする
Cloud Shell で、次の gcloud
コマンドを使用して Cloud TPU API を有効にします(Google Cloud コンソールから有効にすることもできます)。
$ gcloud services enable tpu.googleapis.com
次のコマンドを実行して、サービス ID(サービス アカウント)を作成します。
$ gcloud beta services identity create --service tpu.googleapis.com
TPU VM を作成する
Cloud TPU VM では、モデルとコードは TPU VM 上で直接実行されます。TPU VM には、直接 SSH で接続します。TPU VM では、任意のコードの実行、パッケージのインストール、ログの表示、コードのデバッグを直接行えす。
TPU VM は、Google Cloud Shell か、Google Cloud CLI がインストールされているコンピュータ ターミナルから、次のコマンドを実行して作成します。
契約での利用可否に基づいて zone
を設定します。必要に応じて TPU のリージョンとゾーンをご覧ください。
accelerator-type
変数を v2-8、v3-8、または v4-8 に設定します。
version
変数を、v2 と v3 の TPU バージョンの場合は tpu-vm-base
、v4 TPU の場合は tpu-vm-v4-base
に設定します。
$ gcloud compute tpus tpu-vm create tpu-name \ --zone zone \ --accelerator-type accelerator-type \ --version version
Google Cloud TPU VM に接続する
次のコマンドを使用して、TPU VM に SSH 接続します。
$ gcloud compute tpus tpu-vm ssh tpu-name --zone zone
VM にログインすると、シェル プロンプトが username@projectname
から username@vm-name
に変わります。
Google Cloud TPU VM に Pax をインストールする
次のコマンドを使用して、TPU VM に Pax、JAX、libtpu
をインストールします。
(vm)$ python3 -m pip install -U pip \ python3 -m pip install paxml jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
システム チェック
JAX から TPU コアが見えることを確認することで、すべてが正しくインストールされていることをテストします。
(vm)$ python3 -c "import jax; print(jax.device_count())"
TPU コアの数が表示されます。v2-8 または v3-8 を使用している場合は 8、v4-8 を使用している場合は 4 になるはずです。
TPU VM で Pax コードを実行する
これで、任意の Pax コードを実行できるようになりました。lm_cloud の例は、Pax でモデルの実行を開始するのに最適です。たとえば、次のコマンドは、合成データで 2B パラメータ Transformer ベースの SPMD 言語モデルをトレーニングします。
次のコマンドは、SPMD 言語モデルのトレーニング出力を示しています。300 ステップが約 20 分でトレーニングされます。
(vm)$ python3 .local/lib/python3.10/site-packages/paxml/main.py --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps --job_log_dir=job_log_dir
v4-8 スライスでは、出力には次のものが含まれます。
損失とステップ時間
ステップ = step_# loss
= loss でのサマリー テンソル
ステップ = step_# のステップ/秒 x でのサマリー テンソル
クリーンアップ
このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。
TPU VM の使用を終了したら、次の手順に沿ってリソースをクリーンアップします。
Compute Engine インスタンスとの接続を切断していない場合は切断します。
(vm)$ exit
Cloud TPU を削除します。
$ gcloud compute tpus tpu-vm delete tpu-name --zone zone
次のステップ
Cloud TPU の詳細については、以下をご覧ください。