在 v6e TPU 上进行 MaxDiffusion 推理
本教程介绍了如何在 TPU v6e 上部署 MaxDiffusion 模型。在本教程中,您将使用 Stable Diffusion XL 模型生成图片。
准备工作
准备预配具有 4 个芯片的 TPU v6e:
- 登录您的 Google 账号。如果您还没有 Google 账号,请注册新账号。
- 在 Google Cloud 控制台中,从项目选择器页面选择或创建 Google Cloud 项目。
- 为您的 Google Cloud 项目启用结算功能。所有 Google Cloud 使用情况都需要结算。
- 安装 gcloud alpha 组件。
运行以下命令以安装最新版本的
gcloud
组件。gcloud components update
使用 Cloud Shell 通过以下
gcloud
命令启用 TPU API。您也可以从 Google Cloud 控制台启用。gcloud services enable tpu.googleapis.com
为 TPU 虚拟机创建服务身份。
gcloud alpha compute tpus tpu-vm service-identity create --zone=ZONE
创建 TPU 服务账号并授予对 Google Cloud 服务的访问权限。
通过服务账号, Google Cloud TPU 服务可以访问其他 Google Cloud服务。建议使用用户代管式服务账号。请按照以下指南创建和授予角色。您需要拥有以下角色:
- TPU 管理员:创建 TPU 所需
- Storage Admin:需要此角色才能访问 Cloud Storage
- 日志写入器:需要使用 Logging API 写入日志
- Monitoring Metric Writer:用于将指标写入 Cloud Monitoring
使用 Google Cloud 进行身份验证,并为 Google Cloud CLI 配置默认项目和可用区。
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
保障容量
请与您的 Cloud TPU 销售团队或客户支持团队联系,申请 TPU 配额并咨询容量方面的任何问题。
预配 Cloud TPU 环境
您可以使用 GKE、GKE 和 XPK 预配 v6e TPU,也可以将其作为队列化资源预配。
前提条件
- 验证您的项目是否有足够的
TPUS_PER_TPU_FAMILY
配额,该配额指定您可以在Google Cloud 项目中访问的芯片数量上限。 - 本教程使用以下配置进行了测试:
- Python
3.10 or later
- 每夜软件版本:
- 每夜 JAX
0.4.32.dev20240912
- 每夜 LibTPU
0.1.dev20240912+nightly
- 每夜 JAX
- 稳定版软件版本:
v0.4.35
的 JAX + JAX 库
- Python
- 验证您的项目是否有足够的 TPU 配额,以便:
- TPU 虚拟机配额
- IP 地址配额
- Hyperdisk Balanced 配额
- 用户项目权限
- 如果您将 GKE 与 XPK 搭配使用,请参阅用户账号或服务账号的 Cloud 控制台权限,了解运行 XPK 所需的权限。
预配 TPU v6e
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
使用 list
或 describe
命令查询队列中资源的状态。
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
如需查看已加入队列的资源请求状态的完整列表,请参阅已加入队列的资源文档。
使用 SSH 连接到 TPU
gcloud compute tpus tpu-vm ssh TPU_NAME
创建 Conda 环境
为 Miniconda 创建一个目录:
mkdir -p ~/miniconda3
下载 Miniconda 安装程序脚本:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
安装 Miniconda:
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
移除 Miniconda 安装程序脚本:
rm -rf ~/miniconda3/miniconda.sh
将 Miniconda 添加到
PATH
变量:export PATH="$HOME/miniconda3/bin:$PATH"
重新加载
~/.bashrc
以将更改应用于PATH
变量:source ~/.bashrc
创建一个新的 Conda 环境:
conda create -n tpu python=3.10
激活 Conda 环境:
source activate tpu
设置 MaxDiffusion
克隆 MaxDiffusion 代码库并进入 MaxDiffusion 目录:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
切换到
mlperf-4.1
分支:git checkout mlperf4.1
安装 MaxDiffusion:
pip install -e .
安装依赖项:
pip install -r requirements.txt
安装 JAX:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
安装其他依赖项:
pip install huggingface_hub==0.25 absl-py flax tensorboardX google-cloud-storage torch tensorflow transformers
生成图片
设置环境变量以配置 TPU 运行时:
LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
使用
src/maxdiffusion/configs/base_xl.yml
中定义的提示和配置生成图片:python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
生成映像后,请务必清理 TPU 资源。
清理
删除 TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async