在 v6e TPU 虚拟机上进行 JetStream PyTorch 推理

本教程介绍了如何使用 JetStream 在 TPU v6e 上提供 PyTorch 模型。JetStream 是一款针对 XLA 设备 (TPU) 上的大语言模型 (LLM) 推理进行了吞吐量和内存优化的引擎。在本教程中,您将针对 Llama2-7B 模型运行推理基准测试。

准备工作

准备预配具有 4 个芯片的 TPU v6e:

  1. 登录您的 Google 账号。如果您还没有 Google 账号,请注册新账号
  2. Google Cloud 控制台中,从项目选择器页面选择创建一个 Google Cloud 项目。
  3. 为您的 Google Cloud 项目启用结算功能。所有 Google Cloud 使用都需要结算。
  4. 安装 gcloud alpha 组件
  5. 运行以下命令以安装最新版本的 gcloud 组件。

    gcloud components update
    
  6. 使用 Cloud Shell 通过以下 gcloud 命令启用 TPU API。您也可以从 Google Cloud 控制台启用。

    gcloud services enable tpu.googleapis.com
    
  7. 为 TPU 虚拟机创建服务身份。

    gcloud alpha compute tpus tpu-vm service-identity create --zone=ZONE
  8. 创建 TPU 服务账号,并授予对 Google Cloud 服务的访问权限。

    借助服务账号, Google Cloud TPU 服务可以访问其他 Google Cloud服务。建议使用用户代管式服务账号。请按照以下指南创建授予角色。您需要拥有以下角色:

    • TPU 管理员:创建 TPU 所需
    • Storage Admin:需要此角色才能访问 Cloud Storage
    • 日志写入器:需要使用 Logging API 写入日志
    • Monitoring Metric Writer:用于将指标写入 Cloud Monitoring
  9. 使用 Google Cloud 进行身份验证,并为 Google Cloud CLI 配置默认项目和区域。

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

保障容量

请与您的 Cloud TPU 销售团队或客户支持团队联系,申请 TPU 配额并咨询容量方面的任何问题。

预配 Cloud TPU 环境

您可以使用 GKE、GKE 和 XPK 预配 v6e TPU,也可以将其作为队列化资源预配。

前提条件

  • 验证您的项目是否有足够的 TPUS_PER_TPU_FAMILY 配额,该配额指定您可以在Google Cloud 项目中访问的条状标签的数量上限。
  • 本教程使用以下配置进行了测试:
    • Python 3.10 or later
    • 每夜软件版本:
      • 每夜 JAX 0.4.32.dev20240912
      • 每夜 LibTPU 0.1.dev20240912+nightly
    • 稳定版软件版本:
      • v0.4.35 的 JAX + JAX 库
  • 验证您的项目是否有足够的 TPU 配额,以便:
    • TPU 虚拟机配额
    • IP 地址配额
    • Hyperdisk Balanced 配额
  • 用户项目权限

创建环境变量

在 Cloud Shell 中,创建以下环境变量:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-4
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

命令标志说明

变量 说明
NODE_ID 在队列化资源请求分配时创建的 TPU 的用户分配 ID。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目
ZONE 如需了解支持的区域,请参阅 TPU 区域和可用区文档。
ACCELERATOR_TYPE 如需了解支持的加速器类型,请参阅加速器类型文档。
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT 这是您的服务账号的电子邮件地址,您可以在 Google Cloud 控制台 -> IAM -> 服务账号中找到该地址
例如:tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com
NUM_SLICES 要创建的 Slice 的数量(仅适用于多 Slice)
QUEUED_RESOURCE_ID 已加入队列的资源请求的用户分配的文本 ID。
VALID_DURATION 队列中资源请求的有效时长。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

预配 TPU v6e

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
    

使用 listdescribe 命令查询队列中资源的状态。

   gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
      --project ${PROJECT_ID} --zone ${ZONE}

如需查看已加入队列的资源请求状态的完整列表,请参阅已加入队列的资源文档。

使用 SSH 连接到 TPU

  gcloud compute tpus tpu-vm ssh TPU_NAME

运行 JetStream PyTorch Llama2-7B 基准测试

如需设置 JetStream-PyTorch、转换模型检查点并运行推理基准测试,请按照 GitHub 代码库中的说明操作。

推理基准测试完成后,请务必清理 TPU 资源。

清理

删除 TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async