v6e TPU での MaxDiffusion 推論
このチュートリアルでは、TPU v6e で MaxDiffusion モデルをサービングする方法について説明します。このチュートリアルでは、Stable Diffusion XL モデルを使用して画像を生成します。
始める前に
4 個のチップを搭載した TPU v6e をプロビジョニングする準備を行います。
- Cloud TPU 環境を設定するガイドに沿って、 Google Cloud プロジェクトを設定して、Google Cloud CLI の構成を行い、Cloud TPU API を有効にして、Cloud TPU を使用するためのアクセス権があることを確認します。 
- Google Cloud で認証し、Google Cloud CLI のデフォルトのプロジェクトとゾーンを構成します。 - gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE 
容量を確保する
TPU 容量を確保する準備ができたら、Cloud TPU 割り当てで Cloud TPU 割り当ての詳細を確認してください。容量の確保についてご不明な点がございましたら、Cloud TPU のセールスチームまたはアカウント チームにお問い合わせください。
Cloud TPU 環境をプロビジョニングする
TPU VM は、GKE、GKE と XPK、またはキューに格納されたリソースとしてプロビジョニングできます。
前提条件
- プロジェクトに十分な TPUS_PER_TPU_FAMILY割り当てがあることを確認します。これは、Google Cloud プロジェクト内でアクセスできるチップの最大数を指します。
- プロジェクトの次の TPU 割り当てが十分にあることを確認します。- TPU VM の割り当て
- IP アドレスの割り当て
- Hyperdisk Balanced の割り当て
 
- ユーザー プロジェクトの権限
- GKE と XPK を使用している場合は、XPK の実行に必要な権限について、ユーザーまたはサービス アカウントに対する Google Cloud コンソールの権限をご覧ください。
 
TPU v6e をプロビジョニングする
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
list コマンドまたは describe コマンドを使用して、キューに格納されたリソースのステータスをクエリします。
gcloud alpha compute tpus queued-resources describe QUEUED_RESOURCE_ID \ --project=PROJECT_ID --zone=ZONE
キューに格納されたリソース リクエストのステータスの一覧については、キューに格納されたリソースのドキュメントをご覧ください。
SSH を使用して TPU に接続する
gcloud compute tpus tpu-vm ssh TPU_NAME
Conda 環境を作成する
- Miniconda のディレクトリを作成します。 - mkdir -p ~/miniconda3 
- Miniconda インストーラ スクリプトをダウンロードします。 - wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh 
- Miniconda をインストールします。 - bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 
- Miniconda インストーラ スクリプトを削除します。 - rm -rf ~/miniconda3/miniconda.sh 
- Miniconda を - PATH変数に追加します。- export PATH="$HOME/miniconda3/bin:$PATH" 
- ~/.bashrcを再読み込みして、- PATH変数に変更を適用します。- source ~/.bashrc 
- 新しい Conda 環境を作成します。 - conda create -n tpu python=3.10 
- Conda 環境をアクティブにします。 - source activate tpu 
MaxDiffusion を設定する
- MaxDiffusion GitHub リポジトリのクローンを作成し、MaxDiffusion ディレクトリに移動します。 - git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion 
- mlperf-4.1ブランチに切り替えます。- git checkout mlperf4.1 
- MaxDiffusion をインストールします。 - pip install -e . 
- 依存関係をインストールします。 - pip install -r requirements.txt 
- JAX をインストールします。 - pip install jax[tpu]==0.4.34 jaxlib==0.4.34 ml-dtypes==0.2.0 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 
- 追加の依存関係をインストールします。 - pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers 
画像を生成する
- 環境変数を設定して TPU ランタイムを構成します。 - LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536" 
- src/maxdiffusion/configs/base_xl.ymlで定義されたプロンプトと構成を使用して画像を生成します。- python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" - イメージが生成されたら、TPU リソースを必ずクリーンアップしてください。 
クリーンアップ
TPU を削除します。
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async