在 v6e TPU 虚拟机上进行 JetStream PyTorch 推理
本教程介绍了如何使用 JetStream 在 TPU v6e 上提供 PyTorch 模型。JetStream 是一款针对 XLA 设备 (TPU) 上的大语言模型 (LLM) 推理进行了吞吐量和内存优化的引擎。在本教程中,您将针对 Llama2-7B 模型运行推理基准测试。
准备工作
准备预配具有 4 个芯片的 TPU v6e:
- 登录您的 Google 账号。如果您还没有 Google 账号,请注册新账号。
- 在 Google Cloud 控制台中,从项目选择器页面选择或创建一个 Google Cloud 项目。
- 为您的 Google Cloud 项目启用结算功能。所有 Google Cloud 使用都需要结算。
- 安装 gcloud alpha 组件。
运行以下命令以安装最新版本的
gcloud
组件。gcloud components update
使用 Cloud Shell 通过以下
gcloud
命令启用 TPU API。您也可以从 Google Cloud 控制台启用。gcloud services enable tpu.googleapis.com
为 TPU 虚拟机创建服务身份。
gcloud alpha compute tpus tpu-vm service-identity create --zone=ZONE
创建 TPU 服务账号,并授予对 Google Cloud 服务的访问权限。
借助服务账号, Google Cloud TPU 服务可以访问其他 Google Cloud服务。建议使用用户代管式服务账号。请按照以下指南创建和授予角色。您需要拥有以下角色:
- TPU 管理员:创建 TPU 所需
- Storage Admin:需要此角色才能访问 Cloud Storage
- 日志写入器:需要使用 Logging API 写入日志
- Monitoring Metric Writer:用于将指标写入 Cloud Monitoring
使用 Google Cloud 进行身份验证,并为 Google Cloud CLI 配置默认项目和区域。
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
保障容量
请与您的 Cloud TPU 销售团队或客户支持团队联系,申请 TPU 配额并咨询容量方面的任何问题。
预配 Cloud TPU 环境
您可以使用 GKE、GKE 和 XPK 预配 v6e TPU,也可以将其作为队列化资源预配。
前提条件
- 验证您的项目是否有足够的
TPUS_PER_TPU_FAMILY
配额,该配额指定您可以在Google Cloud 项目中访问的条状标签的数量上限。 - 本教程使用以下配置进行了测试:
- Python
3.10 or later
- 每夜软件版本:
- 每夜 JAX
0.4.32.dev20240912
- 每夜 LibTPU
0.1.dev20240912+nightly
- 每夜 JAX
- 稳定版软件版本:
v0.4.35
的 JAX + JAX 库
- Python
- 验证您的项目是否有足够的 TPU 配额,以便:
- TPU 虚拟机配额
- IP 地址配额
- Hyperdisk Balanced 配额
- 用户项目权限
- 如果您将 GKE 与 XPK 搭配使用,请参阅用户账号或服务账号的 Cloud 控制台权限,了解运行 XPK 所需的权限。
创建环境变量
在 Cloud Shell 中,创建以下环境变量:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-4 export ZONE=us-central2-b export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the # default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
命令标志说明
变量 | 说明 |
NODE_ID | 在队列化资源请求分配时创建的 TPU 的用户分配 ID。 |
PROJECT_ID | Google Cloud 项目名称。使用现有项目或创建新项目。 |
ZONE | 如需了解支持的区域,请参阅 TPU 区域和可用区文档。 |
ACCELERATOR_TYPE | 如需了解支持的加速器类型,请参阅加速器类型文档。 |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | 这是您的服务账号的电子邮件地址,您可以在 Google Cloud 控制台 -> IAM -> 服务账号中找到该地址
例如:tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | 要创建的 Slice 的数量(仅适用于多 Slice) |
QUEUED_RESOURCE_ID | 已加入队列的资源请求的用户分配的文本 ID。 |
VALID_DURATION | 队列中资源请求的有效时长。 |
NETWORK_NAME | 要使用的辅助网络的名称。 |
NETWORK_FW_NAME | 要使用的次要网络防火墙的名称。 |
预配 TPU v6e
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
使用 list
或 describe
命令查询队列中资源的状态。
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
如需查看已加入队列的资源请求状态的完整列表,请参阅已加入队列的资源文档。
使用 SSH 连接到 TPU
gcloud compute tpus tpu-vm ssh TPU_NAME
运行 JetStream PyTorch Llama2-7B 基准测试
如需设置 JetStream-PyTorch、转换模型检查点并运行推理基准测试,请按照 GitHub 代码库中的说明操作。
推理基准测试完成后,请务必清理 TPU 资源。
清理
删除 TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--force \
--async