Trillium (v6e) 简介

在本文档、TPU API 和日志中,v6e 用于指代 Trillium。v6e 代表 Google 第 6 代 TPU。

v6e 每个 Pod 包含 256 个芯片,与 v5e 有许多相似之处。此系统经过优化,可成为用于转换器、文本到图片和卷积神经网络 (CNN) 训练、微调和服务的最高价值产品。

v6e 系统架构

如需了解 Cloud TPU 配置,请参阅 v6e 文档。

本文档重点介绍使用 JAXPyTorchTensorFlow 框架进行模型训练的设置流程。对于每种框架,您都可以使用队列化资源或 Google Kubernetes Engine (GKE) 预配 TPU。您可以使用 XPK 或 GKE 命令进行 GKE 设置。

准备 Google Cloud 项目

  1. 登录您的 Google 账号。如果您还没有 Google 账号,请注册新账号
  2. Google Cloud 控制台中,从项目选择器页面选择创建 Cloud 项目。
  3. 为您的 Google Cloud 项目启用结算功能。所有 Google Cloud 使用情况都需要结算。
  4. 安装 gcloud alpha 组件
  5. 运行以下命令以安装最新版本的 gcloud 组件。

    gcloud components update
    
  6. 在 Cloud Shell 中通过以下 gcloud 命令启用 TPU API。您也可以从 Google Cloud 控制台启用。

    gcloud services enable tpu.googleapis.com
    
  7. 为 Compute Engine API 启用 TPU 服务账号权限

    通过服务账号,Cloud TPU 服务可以访问其他 Google Cloud 服务。用户代管式服务账号是 Google Cloud 的推荐做法。请按照以下指南创建授予角色。您需要拥有以下角色:

    • TPU Admin
    • Storage Admin
    • 日志写入者
    • Monitoring Metric Writer

    a. 使用您的用户账号为 GKE 设置 XPK 权限:XPK

  8. 为项目 ID 和可用区创建环境变量。

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. 为 TPU 虚拟机创建服务身份。

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

保障容量

请与您的 Cloud TPU 支持销售/客户支持团队联系,申请 TPU 配额并解答容量方面的任何问题。

预配 Cloud TPU 环境

v6e TPU 可以使用 GKE、GKE 和 XPK(一种基于 GKE 的封装容器 CLI 工具)进行预配和管理,也可以作为队列化资源进行管理。

前提条件

  • 验证您的项目是否有足够的 TPUS_PER_TPU_FAMILY 配额,该配额指定您可以在 Google Cloud 项目中访问的最大条状标签数量。
  • v6e 已通过以下配置进行测试:
    • Python 3.10 或更高版本
    • 每夜软件版本:
      • 每夜 JAX 0.4.32.dev20240912
      • 每夜 LibTPU 0.1.dev20240912+nightly
    • 稳定版软件版本:
      • JAX + v0.4.35 的 JAX 库
  • 验证您的项目是否有足够的 TPU 配额,以便:
    • TPU 虚拟机配额
    • IP 地址配额
    • Hyperdisk-balance 配额
  • 用户项目权限

环境变量

在 Cloud Shell 中,创建以下环境变量:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

命令标志说明

变量 说明
NODE_ID 用户分配给 TPU 的 ID,在队列中的资源请求分配时创建。
PROJECT_ID Google Cloud 项目名称。在 中使用现有项目或创建新项目
ZONE 如需了解支持的区域,请参阅 TPU 区域和可用区文档。
ACCELERATOR_TYPE 请参阅加速器类型
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT 这是您的服务账号的电子邮件地址,您可以在 Google Cloud Console -> IAM -> 服务账号中找到该地址

例如:tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES 要创建的 Slice 的数量(仅适用于多 Slice)
QUEUED_RESOURCE_ID 已加入队列的资源请求的用户分配的文本 ID。
VALID_DURATION 队列中的资源请求有效的时长。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

网络性能优化

为了获得最佳性能,请使用 MTU(最大传输单元)为 8,896 的网络。

默认情况下,Virtual Private Cloud (VPC) 仅提供 1,460 字节的 MTU,这会导致网络性能不佳。您可以将 VPC 网络的 MTU 设置为 1300 字节到 8896 字节之间(含边界值)的任何值。常见的自定义 MTU 大小为 1500 字节(标准以太网)或 8896 字节(可能的最大值)。如需了解详情,请参阅有效的 VPC 网络 MTU 大小

如需详细了解如何更改现有网络或默认网络的 MTU 设置,请参阅更改 VPC 网络的 MTU 设置

以下示例会创建一个 MTU 为 8,896 的网络

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

使用多 NIC(适用于多 Slice)

使用多 slice 环境时,辅助子网需要以下环境变量。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

使用以下命令为网络和子网创建自定义 IP 路由。

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

创建多网络 slice 后,您可以通过在 XPK 工作负载中运行 --command ifconfig 来验证是否正在使用两个 NIC。然后,查看 Cloud 控制台日志中该 XPK 工作负载的输出,并检查 eth0 和 eth1 是否均为 mtu=8896。

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

验证 eth0 和 eth1 是否均具有 mtu=8,896。您可以通过在 XPK 工作负载中运行命令 --command "ifconfig" 来验证是否正在运行多 NIC。然后,查看 Cloud 控制台日志中该 xpk 工作负载的输出,并检查 eth0 和 eth1 是否均为 mtu=8896。

改进了 TCP 设置

对于使用队列化资源接口创建的 TPU,您可以运行以下命令,通过更改 rto_minquickack 的默认 TCP 设置来提升网络性能。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

使用已排队资源进行预配 (Cloud TPU API)

您可以使用队列化资源 create 命令预配容量。

  1. 创建 TPU 队列资源请求。

    --reserved 标志仅适用于预留资源,而不适用于按需资源。

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    如果成功创建了队列中的资源请求,“response”字段中的状态将为“WAITING_FOR_RESOURCES”或“FAILED”。如果已排队的资源请求处于“WAITING_FOR_RESOURCES”状态,则表示已排队的资源已加入队列,并会在有足够的 TPU 容量时进行预配。如果队列中的资源请求处于“FAILED”(失败)状态,输出中会显示失败原因。如果未在指定时长内预配 v6e,队列中的资源请求将过期,并且状态变为“FAILED”。如需了解详情,请参阅已加入队列的资源公开文档。

    当已排队的资源请求处于“ACTIVE”(有效)状态时,您可以使用 SSH 连接到 TPU 虚拟机。使用 listdescribe 命令查询队列中资源的状态。

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    当队列中的资源处于“ACTIVE”(活动)状态时,输出类似于以下内容:

      state:
       state: ACTIVE
    
  2. 管理 TPU 虚拟机。如需了解管理 TPU 虚拟机的选项,请参阅管理 TPU 虚拟机

  3. 使用 SSH 连接到 TPU 虚拟机

    您可以在 TPU 切片中的每个 TPU 虚拟机上安装二进制文件并运行代码。请参阅虚拟机类型部分,确定您的 slice 将包含多少个虚拟机。

    如需安装二进制文件或运行代码,您可以使用 SSH 通过 tpu-vm ssh 命令连接到虚拟机。

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    如需使用 SSH 连接到特定虚拟机,请使用 --worker 标志,该标志应紧随从 0 开始编号的编号:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    如果您的 slice 形状大于 8 个条状标签,则一个 slice 中将包含多个虚拟机。在这种情况下,请在 gcloud alpha compute tpus tpu-vm ssh 命令中使用 --worker=all--command 参数,以便在所有虚拟机上同时运行命令。例如:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 删除已排队的资源

    在会话结束时删除已排队的资源,或移除处于“FAILED”状态的已排队资源请求。如需删除已排队的资源,请按以下 2 个步骤删除 slice 和已排队的资源请求:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

将 GKE 与 v6e 搭配使用

如果您将 GKE 命令与 v6e 搭配使用,则可以使用 Kubernetes 命令或 XPK 预配 TPU 并训练或部署模型。如需了解如何将 GKE 与 TPU 和 v6e 搭配使用,请参阅在 GKE 中规划 TPU

框架设置

本部分介绍了使用 JAXPyTorchTensorFlow 框架进行机器学习模型训练的一般设置流程。您可以使用队列化资源或 GKE 预配 TPU。您可以使用 XPK 或 Kubernetes 命令进行 GKE 设置。

使用排队的资源设置 JAX

使用 gcloud alpha compute tpus tpu-vm ssh 同时在切片中的所有 TPU 虚拟机上安装 JAX。对于多切片,请添加 --node=all


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

您可以运行以下 Python 代码,检查 slice 中可用的 TPU 核心数量,并测试是否已正确安装所有组件(此处显示的输出是使用 v6e-16 slice 生成的):

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

输出类似于以下内容:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() 会显示给定 slice 中的芯片总数。jax.local_device_count() 表示此 slice 中单个虚拟机可访问的芯片数量。

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

排查 JAX 设置问题

一般提示是在 GKE 工作负载清单中启用详细日志记录。然后,将日志提供给 GKE 支持团队。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

错误消息

no endpoints available for service 'jobset-webhook-service'

此错误表示作业集未正确安装。检查 jobset-controller-manager 部署 Kubernetes Pod 是否正在运行。如需了解详情,请参阅 JobSet 问题排查文档

TPU initialization failed: Failed to connect

确保您的 GKE 节点版本为 1.30.4-gke.1348000 或更高版本(不支持 GKE 1.31)。

PyTorch 设置

本部分介绍了如何开始在 v6e 上使用 PyTorch/XLA 的 PJRT。建议使用 Python 3.10。

使用 XPK 通过 GKE 设置 PyTorch

您可以将以下 Docker 容器与已安装 PyTorch 依赖项的 XPK 搭配使用:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

如需创建 XPK 工作负载,请使用以下命令:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

使用 --base-docker-image 会创建一个新的 Docker 映像,并将当前工作目录内置到新的 Docker 中。

使用已排队的资源设置 PyTorch

请按照以下步骤使用队列化资源安装 PyTorch,并在 v6e 上运行一个小脚本。

使用 SSH 安装依赖项以访问虚拟机。

对于多切片,请添加 --node=all

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

提高具有可扩缩、频繁分配的模型的性能

对于具有较大规模、高频率的分配的模型,我们发现,与使用默认的 malloc 实现相比,使用 tcmalloc 可以显著提升性能,因此 TPU 虚拟机上默认使用的 malloctcmalloc。但是,根据您的工作负载(例如为其嵌入表进行了超大规模分配的 DLRM),tcmalloc 可能会造成运行缓慢,在这种情况下,您可以尝试设置以下变量以改为使用默认 malloc

unset LD_PRELOAD

使用 Python 脚本对 v6e 虚拟机执行计算:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

这将生成如下所示的输出:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

TensorFlow 设置

对于 v6e 公开预览版,仅支持 tf-nightly 运行时版本。

您可以通过运行以下命令,使用与 v6e 兼容的 TensorFlow 版本重置 tpu-runtime

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

使用 SSH 访问 worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

在 worker-0 上安装 TensorFlow:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

导出 TPU_NAME 环境变量:

export TPU_NAME=v6e-16

您可以运行以下 Python 脚本,检查 slice 中可用的 TPU 核心数量,并测试是否已正确安装所有内容(显示的输出是使用 v6e-16 slice 生成的):

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

输出类似于以下内容:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

带有 SkyPilot 的 v6e

您可以将 TPU v6e 与 SkyPilot 搭配使用。请按照以下步骤向 SkyPilot 添加与 v6e 相关的位置/价格信息。

  1. 将以下代码添加到 ~/.sky/catalogs/v5/gcp/vms.csv 的末尾:

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. 在 YAML 文件中指定以下资源:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. 启动具有 TPU v6e 的集群:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. 使用 SSH 连接到 TPU v6e:ssh tpu_v6

推理教程

以下部分提供了使用 JetStream 提供 MaxText 和 PyTorch 模型以及在 TPU v6e 上提供 MaxDiffusion 模型的教程。

JetStream 上的 MaxText

本教程介绍了如何使用 JetStream 在 TPU v6e 上提供 MaxText (JAX) 模型。JetStream 是一款针对 XLA 设备 (TPU) 上的大语言模型 (LLM) 推理进行了吞吐量和内存优化的引擎。在本教程中,您将针对 Llama2-7B 模型运行推理基准测试。

准备工作

  1. 创建具有 4 个芯片的 TPU v6e:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. 使用 SSH 连接到 TPU:

    gcloud compute tpus tpu-vm ssh TPU_NAME

运行教程

如需设置 JetStream 和 MaxText、转换模型检查点并运行推理基准测试,请按照 GitHub 代码库中的说明操作。

清理

删除 TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

在 PyTorch TPU 上使用 vLLM

以下是一个简单的教程,介绍了如何在 TPU 虚拟机上开始使用 vLLM。我们将在未来几天内发布 GKE 用户指南,其中将介绍在生产环境中将 vLLM 部署到 Trillium 的最佳实践示例(敬请关注!)。

准备工作

  1. 创建一个包含 4 个芯片的 TPU v6e:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    命令标志说明

    变量 说明
    NODE_ID 用户分配给 TPU 的 ID,该 ID 在队列中的资源请求分配时创建。
    PROJECT_ID Google Cloud 项目名称。在 中使用现有项目或创建新项目
    ZONE 如需了解支持的区域,请参阅 TPU 区域和可用区文档。
    ACCELERATOR_TYPE 请参阅加速器类型
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT 这是您的服务账号的电子邮件地址,您可以在 Google Cloud Console -> IAM -> 服务账号中找到该地址

    例如:tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. 使用 SSH 连接到 TPU:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

在 TPU 上设置 vLLM

  1. 克隆 vLLM 代码库并进入 vLLM 目录:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. 清理现有的 torch 和 torch-xla 软件包:

    pip uninstall torch torch-xla -y
    
  3. 安装 PyTorch 和 PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. 安装 JAX 和 Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. 安装其他构建依赖项:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

获取对模型的访问权限

您必须签署同意协议,才能使用 HuggingFace 代码库中的 Llama3 系列模型

如果您还没有 Hugging Face 令牌,请生成一个新令牌:

  1. 点击您的个人资料 > 设置 > 访问令牌
  2. 选择新建令牌 (New Token)。
  3. 指定您选择的名称和一个至少为 Read 的角色。
  4. 选择生成令牌
  5. 将生成的令牌复制到剪贴板,将其设置为环境变量,然后使用 huggingface-cli 进行身份验证:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

下载基准比较数据

  1. 创建一个 /data 目录,然后从 Hugging Face 下载 ShareGPT 数据集。

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

启动 vLLM 服务器

以下命令会将模型权重从 Hugging Face 模型中心下载到 TPU 虚拟机的 /tmp 目录,预编译一系列输入形状,并将模型编译结果写入 ~/.cache/vllm/xla_cache

如需了解详情,请参阅 vLLM 文档

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

运行 vLLM 基准测试

运行 vLLM 基准测试脚本:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

清理

删除 TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

JetStream 上的 PyTorch

本教程介绍了如何使用 JetStream 在 TPU v6e 上提供 PyTorch 模型。 JetStream 是一款针对 XLA 设备 (TPU) 上的大语言模型 (LLM) 推理进行了吞吐量和内存优化的引擎。在本教程中,您将针对 Llama2-7B 模型运行推理基准测试。

准备工作

  1. 创建具有 4 个芯片的 TPU v6e:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. 使用 SSH 连接到 TPU:

    gcloud compute tpus tpu-vm ssh TPU_NAME

运行教程

如需设置 JetStream-PyTorch、转换模型检查点并运行推理基准测试,请按照 GitHub 代码库中的说明操作。

清理

删除 TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

MaxDiffusion 推理

本教程介绍了如何在 TPU v6e 上部署 MaxDiffusion 模型。在本教程中,您将使用 Stable Diffusion XL 模型生成图片。

准备工作

  1. 创建具有 4 个芯片的 TPU v6e:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. 使用 SSH 连接到 TPU:

    gcloud compute tpus tpu-vm ssh TPU_NAME

创建 Conda 环境

  1. 为 Miniconda 创建一个目录:

    mkdir -p ~/miniconda3
  2. 下载 Miniconda 安装程序脚本:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. 安装 Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. 移除 Miniconda 安装程序脚本:

    rm -rf ~/miniconda3/miniconda.sh
  5. 将 Miniconda 添加到 PATH 变量:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. 重新加载 ~/.bashrc 以将更改应用于 PATH 变量:

    source ~/.bashrc
  7. 创建一个新的 Conda 环境:

    conda create -n tpu python=3.10
  8. 激活 Conda 环境:

    source activate tpu

设置 MaxDiffusion

  1. 克隆 MaxDiffusion 代码库并进入 MaxDiffusion 目录:

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. 切换到 mlperf-4.1 分支:

    git checkout mlperf4.1
  3. 安装 MaxDiffusion:

    pip install -e .
  4. 安装依赖项:

    pip install -r requirements.txt
  5. 安装 JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

生成图片

  1. 设置环境变量以配置 TPU 运行时:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. 使用 src/maxdiffusion/configs/base_xl.yml 中定义的提示和配置生成图片:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

清理

删除 TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

培训教程

以下部分提供了 MaxText 训练教程,

TPU v6e 上的 MaxDiffusion 和 PyTorch 模型。

MaxText 和 MaxDiffusion

以下部分介绍了 MaxTextMaxDiffusion 模型的训练生命周期。

一般而言,大致步骤如下:

  1. 构建工作负载基础映像。
  2. 使用 XPK 运行工作负载。
    1. 为工作负载构建训练命令。
    2. 部署工作负载。
  3. 跟踪工作负载并查看指标。
  4. 如果不需要 XPK 工作负载,请将其删除。
  5. 不再需要 XPK 集群时,请将其删除。

构建基础映像

安装 MaxText 或 MaxDiffusion 并构建 Docker 映像:

  1. 克隆要使用的代码库,然后切换到该代码库的目录:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. 将 Docker 配置为使用 Google Cloud CLI:

    gcloud auth configure-docker
    
  3. 使用以下命令或 JAX 稳定版堆栈构建 Docker 映像。如需详细了解 JAX Stable Stack,请参阅使用 JAX Stable Stack 构建 Docker 映像

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. 如果您要从未在本地构建映像的机器启动工作负载,请上传映像:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
使用 JAX 稳定堆栈构建 Docker 映像

您可以使用 JAX Stable Stack 基本映像构建 MaxText 和 MaxDiffusion Docker 映像。

JAX 稳定版堆栈通过将 JAX 与 orbaxflaxoptax 等核心软件包以及经过充分限定的 libtpu.so 捆绑在一起,为 MaxText 和 MaxDiffusion 提供了一致的环境,以驱动 TPU 程序实用程序和其他基本工具。这些库经过测试以确保兼容性,为构建和运行 MaxText 和 MaxDiffusion 提供了稳定的基础,并消除了因软件包版本不兼容而导致的潜在冲突。

JAX 稳定版堆栈包含一个完全发布且经过认证的 libtpu.so,它是驱动 TPU 程序编译、执行和 ICI 网络配置的核心库。libtpu 版本取代了 JAX 之前使用的每夜 build,并通过 HLO/StableHLO IR 中的 PJRT 级资格测试确保 TPU 上的 XLA 计算功能一致。

如需使用 JAX 稳定堆栈构建 MaxText 和 MaxDiffusion Docker 映像,请在运行 docker_build_dependency_image.sh 脚本时,将 MODE 变量设置为 stable_stack,并将 BASEIMAGE 变量设置为要使用的基准映像。

以下示例将 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 指定为基础映像:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

如需查看可用的 JAX 稳定堆栈基础映像的列表,请参阅 Artifact Registry 中的 JAX 稳定堆栈映像

使用 XPK 运行工作负载

  1. 如果您不使用 MaxText 设置的默认值MaxDiffusion,请设置以下环境变量:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. 构建模型脚本,以便在下一步中将其复制为训练命令。 暂时不要执行模型脚本。

    MaxText

    MaxText 是一个高性能、高度可伸缩的开源 LLM,采用纯 Python 和 JAX 编写,并以 Google Cloud TPU 和 GPU 为目标平台进行训练和推理。

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    Gemma 是 Google DeepMind 开发的一系列开放权重大语言模型 (LLM),基于 Gemini 研究和技术。

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral 是 Mistral AI 开发的利用稀疏混合专家 (MoE) 架构的先进 AI 模型。

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama 是由 Meta 开发的一系列开放权重大语言模型 (LLM)。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion 是一系列用纯 Python 和 JAX 编写的参考实现,其中包含在 XLA 设备(包括 Cloud TPU 和 GPU)上运行的各种潜在 diffusion 模型。Stable Diffusion 是一种潜在的文本到图像模型,可根据任何文本输入生成逼真的图片。

    您需要安装特定分支才能运行 MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    训练脚本:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. 使用您在上一步中创建的脚本运行模型。您必须指定 --base-docker-image 标志才能使用 MaxText 基础映像,或者指定 --docker-image 标志和要使用的映像。

    可选:您可以通过添加 --enable-debug-logs 标志来启用调试日志记录。如需了解详情,请参阅在 MaxText 上调试 JAX

    可选:您可以创建 Vertex AI 实验,通过添加 --use-vertex-tensorboard 标志将数据上传到 Vertex AI TensorBoard。如需了解详情,请参阅使用 Vertex AI 监控 MaxText 上的 JAX

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    执行以下变量替换操作:

    • CLUSTER_NAME:XPK 集群的名称。
    • ACCELERATOR_TYPE:TPU 的版本和大小。例如 v6e-256
    • NUM_SLICES:TPU 切片的数量。
    • YOUR_MODEL_SCRIPT:要作为训练命令执行的模型脚本。

    输出中包含用于跟踪工作负载的链接,类似于以下内容:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    打开链接,然后点击日志标签页以实时跟踪工作负载。

在 MaxText 上调试 JAX

使用补充 XPK 命令诊断集群或工作负载未运行的原因:

使用 Vertex AI 监控 MaxText 上的 JAX

通过 Vertex AI 的托管式 TensorBoard 查看标量和配置数据。

  1. 将您所用可用区的资源管理 (CRUD) 请求次数从 600 提高到 5,000。对于使用少于 16 个虚拟机的小型工作负载,这可能不是问题。
  2. 为 Vertex AI 安装 cloud-accelerator-diagnostics 等依赖项:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. 使用 --create-vertex-tensorboard 标志创建 XPK 集群,如创建 Vertex AI TensorBoard 中所述。您也可以在现有集群上运行此命令。

  4. 在运行 XPK 工作负载时,使用 --use-vertex-tensorboard 标志和可选的 --experiment-name 标志创建 Vertex AI 实验。如需查看完整步骤列表,请参阅创建 Vertex AI 实验以将数据上传到 Vertex AI TensorBoard

日志包含指向 Vertex AI TensorBoard 的链接,如下所示:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

您还可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 链接。前往 Google Cloud 控制台中的 Vertex AI 实验。从下拉菜单中选择适当的区域。

TensorBoard 目录也会写入您使用 ${BASE_OUTPUT_DIR} 指定的 Cloud Storage 存储桶。

删除 XPK 工作负载

您可以使用 xpk workload delete 命令根据作业前缀或作业状态删除一个或多个工作负载。如果您发送的 XPK 工作负载不再需要运行,或者您有作业卡在队列中,此命令可能会很有用。

删除 XPK 集群

使用 xpk cluster delete 命令删除集群:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama 和 PyTorch

本教程介绍了如何使用 WikiText 数据集在 TPU v6e 上使用 PyTorch/XLA 训练 Llama 模型。此外,用户还可以在此处以 Docker 映像的形式访问 PyTorch TPU 模型 recrip。

安装

在虚拟环境中安装 pytorch-tpu/transformers 分支的 hugging face Transformer 和依赖项:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

设置模型配置

下一部分(构建模型脚本)中的训练命令使用两个 JSON 配置文件来定义模型参数和 FSDP(完全分片数据并行)配置。 FSDP 分片用于在训练过程中调整模型权重以适应更大的批次大小。使用较小的模型进行训练时,只需使用数据并行处理并在每台设备上复制权重即可。 如需详细了解如何在 PyTorch/XLA 中跨设备分片张量,请参阅 PyTorch/XLA SPMD 用户指南

  1. 创建模型参数配置文件。以下是 Llama3-8B 的模型参数配置。对于其他模型,请在 Hugging Face 上查找配置。例如,请参阅 Llama2-7B 配置

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. 创建 FSDP 配置文件:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    如需详细了解 FSDP,请参阅 FSDPv2

  3. 使用以下命令将配置文件上传到 TPU 虚拟机:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    您还可以在当前工作目录中创建配置文件,并在 XPK 中使用 --base-docker-image 标志。

构建模型脚本

构建模型脚本,使用 --config_name 标志指定模型参数配置文件,使用 --fsdp_config 标志指定 FSDP 配置文件。您将在下一部分(运行模型)中在 TPU 上运行此脚本。暂时不要执行模型脚本。

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

运行模型

使用您在上一步中创建的脚本构建模型脚本运行模型。

如果您使用的是单主机 TPU 虚拟机(例如 v6e-4),则可以直接在 TPU 虚拟机上运行训练命令。如果您使用的是多主机 TPU 虚拟机,请使用以下命令在所有主机上同时运行脚本:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

PyTorch/XLA 问题排查

如果您在上一部分中设置了用于调试的可选变量,则模型的配置文件将存储在变量 PROFILE_LOGDIR 指定的位置。您可以提取存储在此位置的 xplane.pb 文件,并使用 tensorboard 按照 TensorBoard 说明在浏览器中查看配置文件。如果 PyTorch/XLA 的运行情况不符合预期,请参阅问题排查指南,其中提供了有关调试、性能分析和优化模型的建议。

DLRM DCN v2 教程

本教程介绍如何在 TPU v6e 上训练 DLRM DCN v2 模型。

如果您在多主机上运行,请运行以下命令,使用适当的 TensorFlow 版本重置 tpu-runtime。如果您在单个主机上运行,则无需运行以下两个命令。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

通过 SSH 连接到 worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

设置 TPU 名称

export TPU_NAME=${TPU_NAME}

运行 DLRM v2

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

运行 script.sh

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

运行推荐工作负载 (DLRM DCN) 需要以下标志:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

基准测试结果

以下部分包含在 v6e 上针对 DLRM DCN v2 和 MaxDiffusion 进行的基准测试结果。

DLRM DCN v2

DLRM DCN v2 训练脚本在不同规模下运行。请参阅下表中的吞吐量。

v6e-64 v6e-128 v6e-256
训练步骤 7000 7000 7000
全局批次大小 131072 262144 524288
吞吐量(示例/秒) 2975334 5111808 10066329

MaxDiffusion

我们在 v6e-4、v6e-16 和 2xv6e-16 上运行了 MaxDiffusion 的训练脚本。请参阅下表中的吞吐量。

v6e-4 v6e-16 两个 v6e-16
训练步骤 0.069 0.073 0.13
全局批次大小 8 32 64
吞吐量(示例/秒) 115.9 438.4 492.3

集合

v6e 引入了一项名为“集合”的新功能,以便运行分发工作负载的用户获享便利。合集功能仅适用于 v6e。

借助集合,您可以向 Google Cloud 指明哪些 TPU 节点属于分发工作负载。这样,底层 Google Cloud 基础架构便可限制并简化在正常操作过程中可能会应用于训练工作负载的中断。

使用 Cloud TPU API 中的集合

Cloud TPU API 上的单主机集合是一种排队资源,其中设置了特殊标志 (--workload-type = availability-optimized),以向底层基础架构表明该资源用于处理工作负载。

以下命令使用 Cloud TPU API 预配单主机集合:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

监控和配置文件

Cloud TPU v6e 支持使用与上一代 Cloud TPU 相同的方法进行监控和性能分析。如需详细了解监控,请参阅监控 TPU 虚拟机