Trillium (v6e) 简介
在本文档、TPU API 和日志中,v6e 用于指代 Trillium。v6e 代表 Google 第 6 代 TPU。
v6e 架构每个 Pod 包含 256 个芯片,与 v5e 有许多相似之处。此系统针对转换器、文本到图像和卷积神经网络 (CNN) 训练、微调和服务进行了优化。
如需详细了解 v6e 系统架构和配置,请参阅 TPU v6e。
本简介文档重点介绍了使用 JAX、PyTorch 或 TensorFlow 框架进行模型训练和服务的流程。对于每种框架,您都可以使用队列化资源或 GKE 预配 TPU。您可以使用 XPK 或 GKE 命令进行 GKE 设置。
使用 v6e 训练或部署模型的一般流程
- 准备 Google Cloud 项目
- 安全容量
- 预配 Cloud TPU 环境
- 运行模型训练或推理工作负载
准备 Google Cloud 项目
在使用 Cloud TPU 之前,您需要:
- 创建 Google Cloud 已启用结算功能的账号和项目
- 安装 Google Cloud CLI Alpha 版组件
- 启用 Cloud TPU API
- 创建 Cloud TPU 服务代理
- 创建 Cloud TPU 服务账号并授予权限
如需了解详情,请参阅设置 Cloud TPU 环境。
保障容量
如需申请 Cloud TPU v6e 配额,或解答与容量有关的任何问题,请与 Google Cloud 支持团队联系。
预配 Cloud TPU 环境
v6e Cloud TPU 可以使用 GKE、GKE 和 XPK(一种基于 GKE 的封装容器 CLI 工具)进行预配和管理,也可以作为队列化资源进行管理。
前提条件
- 验证您的项目是否有足够的
TPUS_PER_TPU_FAMILY
配额,该配额指定您可以在 Google Cloud项目中访问的芯片数量上限。 - v6e 已通过以下配置进行测试:
- Python
3.10
或更高版本 - 每夜软件版本:
- 每夜 JAX
0.4.32.dev20240912
- 每夜 LibTPU
0.1.dev20240912+nightly
- 每夜 JAX
- 稳定版软件版本:
- JAX + v0.4.37 的 JAX 库
- Python
请验证您的项目是否有足够的配额来执行以下操作:
- Cloud TPU 虚拟机配额
- IP 地址配额
Hyperdisk Balanced 配额
如果您将 GKE 与 XPK 搭配使用,请参阅用户或服务账号的 Cloud 控制台权限,了解运行 XPK 所需的权限。
创建环境变量
在 Cloud Shell 中,创建以下环境变量:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=number-of-slices # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
命令标志说明
变量 | 说明 |
NODE_ID | 用户分配给 Cloud TPU 的 ID,该 ID 在分配已排队的资源请求时创建。 |
PROJECT_ID | Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目。 |
区域 | 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。 |
ACCELERATOR_TYPE | 请参阅加速器类型。 |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | 这是您的服务账号的电子邮件地址,您可以在 Google Cloud Console -> IAM -> 服务账号
例如:tpu-service-account@ |
NUM_SLICES | 要创建的 Slice 的数量(仅适用于多 Slice)。 |
QUEUED_RESOURCE_ID | 已加入队列的资源请求的用户分配的文本 ID。 |
VALID_DURATION | 队列中资源请求的有效时长。 |
NETWORK_NAME | 要使用的辅助网络的名称。 |
NETWORK_FW_NAME | 要使用的次要网络防火墙的名称。 |
优化网络性能
为了获得最佳性能,请使用 8,896 MTU(最大传输单元)的网络。
默认情况下,虚拟私有云 (VPC) 仅提供 1,460 字节的 MTU,这会导致网络性能不佳。您可以将 VPC 网络的 MTU 设置为 1300 字节到 8896 字节之间(含边界值)的任何值。常见的自定义 MTU 大小为 1500 字节(标准以太网)或 8896 字节(可能的最大值)。如需了解详情,请参阅有效的 VPC 网络 MTU 大小。
如需详细了解如何更改现有网络或默认网络的 MTU 设置,请参阅更改 VPC 网络的 MTU 设置。
以下示例会创建一个 MTU 为 8,896 的网络。
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
使用多 NIC(适用于多 Slice)
使用多 slice 环境时,辅助子网需要以下环境变量。
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
使用以下命令为网络和子网创建自定义 IP 路由。
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
创建多网络 slice 后,您可以通过设置 XPK 集群并将 --command ifconfig
标志添加到 XPK 工作负载创建命令,验证是否使用了两个网络接口卡 (NIC)。
使用以下 xpk workload
命令在 Google Cloud 控制台日志中显示 ifconfig
命令的输出,并检查 eth0 和 eth1 是否均为 mtu=8896。
python3 xpk.py workload create \ --cluster your-cluster-name \ (--base-docker-image maxtext_base_image|--docker-image your-cloud-image-name \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
验证 eth0 和 eth1 是否均为 mtu=8,896。您可以通过向 XPK 工作负载创建命令添加 --command ifconfig
标志来验证多 NIC 是否正在运行。在 Google Cloud 控制台日志中查看该 xpk 工作负载的输出,并验证 eth0 和 eth1 的 mtu 均为 8896。
改进 TCP 设置
如果您使用已排队的资源界面创建了 Cloud TPU,则可以运行以下命令,通过增加 TCP 接收缓冲区限制来提升网络性能。
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT}" \ --zone "${ZONE}" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
使用已排队的资源进行预配
您可以使用排队资源创建 Cloud TPU v6e。借助加入队列的资源,您可以在有容量可用时接收容量。您可以指定请求填充的开始时间和结束时间(可选)。如需了解详情,请参阅管理队列中的资源。
使用 GKE 或 XPK 预配 v6e Cloud TPU
如果您将 GKE 命令与 v6e 搭配使用,则可以使用 Kubernetes 命令或 XPK 预配 Cloud TPU,以及训练或部署模型。如需了解如何在 GKE 集群中规划 Cloud TPU 配置,请参阅在 GKE 中规划 Cloud TPU。以下部分提供了用于创建支持单个 NIC 和多 NIC 的 XPK 集群的命令。
创建支持单个 NIC 的 XPK 集群
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-central2-b export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments=${CLUSTER_ARGUMENTS} \ --create-vertex-tensorboard
命令标志说明
变量 | 说明 |
CLUSTER_NAME | XPK 集群的用户分配的名称。 |
PROJECT_ID | Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目。 |
区域 | 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。 |
TPU_TYPE | 请参阅加速器类型。 |
NUM_SLICES | 您要创建的 slice 的数量 |
CLUSTER_ARGUMENTS | 要使用的网络和子网。
例如: |
NUM_SLICES | 要创建的切片数量。 |
NETWORK_NAME | 要使用的辅助网络的名称。 |
NETWORK_FW_NAME | 要使用的次要网络防火墙的名称。 |
创建支持多 NIC 的 XPK 集群
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster=${CLUSTER_NAME} \
--num-slices=${NUM_SLICES} \
--tpu-type=${TPU_TYPE} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--on-demand \
--custom-cluster-arguments=${CLUSTER_ARGUMENTS} \
--custom-nodepool-arguments=${NODE_POOL_ARGUMENTS} \
--create-vertex-tensorboard
命令标志说明
变量 | 说明 |
CLUSTER_NAME | XPK 集群的用户分配的名称。 |
PROJECT_ID | Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目。 |
区域 | 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。 |
TPU_TYPE | 请参阅加速器类型。 |
NUM_SLICES | 您要创建的 slice 的数量 |
CLUSTER_ARGUMENTS | 要使用的网络和子网。
例如: |
NODE_POOL_ARGUMENTS | 要使用的额外节点网络。
例如: |
NUM_SLICES | 要创建的 Slice 的数量(仅适用于多 Slice)。 |
NETWORK_NAME | 要使用的辅助网络的名称。 |
NETWORK_FW_NAME | 要使用的次要网络防火墙的名称。 |
框架设置
本部分介绍了使用 JAX、PyTorch 或 TensorFlow 框架进行机器学习模型训练的一般设置流程。如果您使用的是 GKE,则可以使用 XPK 或 Kubernetes 命令进行框架设置。
JAX 设置
本部分介绍了在 GKE 上运行 JAX 工作负载(无论是否使用 XPK)以及使用队列化资源的设置说明。
使用 GKE 设置 JAX
单个主机上的单个切片
以下示例使用 Kubernetes YAML 文件设置了 2x2 单主机节点池。
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
成功完成后,您应该会在 GKE 日志中看到以下消息:
Total TPU chips: 4
多主机上的单个切片
以下示例使用 Kubernetes YAML 文件设置了 4x4 多主机节点池。
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
成功完成后,您应该会在 GKE 日志中看到以下消息:
Total TPU chips: 16
多主机上的多切片
以下示例使用 Kubernetes YAML 文件设置了两个 4x4 多主机节点池。
前提是,您需要安装 JobSet v0.2.3 或更高版本。
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
成功完成后,您应该会在 GKE 日志中看到以下消息:
Total TPU chips: 32
如需了解详情,请参阅 GKE 文档中的运行多切片工作负载。
为了获得更好的性能,请启用 hostNetwork。
多 NIC
如需在 GKE 中充分利用多 NIC,Kubernetes Pod 清单需要添加其他注解。以下是非 TPU 多 NIC 工作负载示例清单。
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
如果您使用 exec
命令连接到 Kubernetes Pod,则应使用以下代码看到额外的 NIC。
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
使用 GKE 搭配 XPK 设置 JAX
如需使用 GKE 和 XPK 设置 JAX,请参阅 xpk README。
如需使用 MaxText 设置和运行 XPK,请参阅如何运行 MaxText。
使用已排队的资源设置 JAX
使用 gcloud alpha compute tpus tpu-vm ssh
命令同时在切片中的所有 Cloud TPU 虚拟机上安装 JAX。对于多 Slice,请添加 --node=all
标志。
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
您可以运行以下命令,检查您的 slice 中可用的 Cloud TPU 核心数量,并测试是否已正确安装所有组件:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
在 v6e-16 slice 上运行时,输出类似于以下内容:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
显示给定 slice 中的芯片总数。jax.local_device_count()
表示此 slice 中单个虚拟机可访问的芯片数量。
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
排查 JAX 设置问题
一般提示是在 GKE 工作负载清单中启用详细日志记录。然后,将日志提供给 GKE 支持团队。
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
错误消息
no endpoints available for service 'jobset-webhook-service'
此错误表示作业集未正确安装。检查 jobset-controller-manager 部署 Kubernetes Pod 是否正在运行。如需了解详情,请参阅 JobSet 问题排查文档。
TPU initialization failed: Failed to connect
确保您的 GKE 节点版本为 1.30.4-gke.1348000 或更高版本(不支持 GKE 1.31)。
PyTorch 设置
本部分介绍了如何开始在 v6e 上使用 PyTorch/XLA 的 PJRT。建议使用 Python 3.10。
使用 XPK 通过 GKE 设置 PyTorch
您可以将以下 Docker 容器与已安装 PyTorch 依赖项的 XPK 搭配使用:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
如需创建 XPK 工作负载,请使用以下命令:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
使用 --base-docker-image
会创建一个新的 Docker 映像,并将当前工作目录内置到新的 Docker 中。
使用已排队的资源设置 PyTorch
请按照以下步骤使用队列化资源安装 PyTorch,并在 v6e 上运行一个小脚本。
使用 SSH 安装依赖项以访问虚拟机
使用以下命令在所有 Cloud TPU 虚拟机上安装依赖项。对于多 Slice,请添加 --node=all
标志:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
提高具有可扩缩、频繁分配的模型的性能
对于具有可扩缩、频繁分配的模型,与使用默认的 malloc
函数实现相比,使用 tcmalloc
函数可以显著提升性能,因此 Cloud TPU VM 上默认使用的 malloc
函数是 tcmalloc
。但是,根据您的工作负载(例如为其嵌入表进行了超大规模分配的 DLRM),tcmalloc
函数可能会造成运行缓慢,在这种情况下,您可以尝试设置以下变量以改为使用默认 malloc
函数:
unset LD_PRELOAD
使用 Python 脚本对 v6e 虚拟机执行计算
使用以下命令运行一个脚本,该脚本会创建两个张量,将它们相加,然后输出结果。
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
这将生成如下所示的输出:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
TensorFlow 设置
您可以通过运行以下命令,使用与 v6e 兼容的 TensorFlow 版本重置 Cloud TPU 运行时:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
使用 SSH 访问 worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
在 worker-0 上安装 TensorFlow:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
导出 TPU_NAME
环境变量:
export TPU_NAME=v6e-16
您可以运行以下 Python 脚本,检查您的 slice 中可用的 Cloud TPU 核心数量,并测试所有内容是否已正确安装:
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
在 v6e-16 slice 上运行时,输出类似于以下内容:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
配备 SkyPilot 的 v6e
您可以将 Cloud TPU v6e 与 SkyPilot 搭配使用。请按照以下步骤向 SkyPilot 添加与 v6e 相关的位置和价格信息。
将以下代码添加到
~/.sky/catalogs/v5/gcp/vms.csv
文件的末尾:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
在 YAML 文件中指定以下资源:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
使用 Cloud TPU v6e 启动集群:
sky launch tpu_v6.yaml -c tpu_v6
使用 SSH 连接到 Cloud TPU v6e:
ssh tpu_v6
推理教程
以下教程介绍了如何在 Cloud TPU v6e 上运行推理:
训练示例
以下部分提供了在 Cloud TPU v6e 上训练 MaxText、MaxDiffusion 和 PyTorch 模型的示例。
在 v6e Cloud TPU 虚拟机上进行 MaxText 和 MaxDiffusion 训练
以下部分介绍了 MaxText 和 MaxDiffusion 模型的训练生命周期。
一般而言,大致步骤如下:
- 构建工作负载基础映像。
- 使用 XPK 运行工作负载。
- 为工作负载构建训练命令。
- 部署工作负载。
- 跟踪工作负载并查看指标。
- 如果不需要 XPK 工作负载,请将其删除。
- 不再需要 XPK 集群时,请将其删除。
构建基础映像
安装 MaxText 或 MaxDiffusion 并构建 Docker 映像:
克隆要使用的代码库,然后切换到该代码库的目录:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
将 Docker 配置为使用 Google Cloud CLI:
gcloud auth configure-docker
使用以下命令或 JAX 稳定版堆栈构建 Docker 映像。如需详细了解 JAX Stable Stack,请参阅使用 JAX Stable Stack 构建 Docker 映像。
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
如果您要从未在本地构建映像的机器启动工作负载,请上传映像:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
使用 JAX 稳定版堆栈构建 Docker 映像
您可以使用 JAX Stable Stack 基本映像构建 MaxText 和 MaxDiffusion Docker 映像。
JAX 稳定堆栈通过将 JAX 与 orbax
、flax
和 optax
等核心软件包以及经过充分限定的 libtpu.so 捆绑在一起,为 MaxText 和 MaxDiffusion 提供了一致的环境,以驱动 Cloud TPU 程序实用程序和其他基本工具。这些库经过测试,以确保兼容性,并提供构建和运行 MaxText 和 MaxDiffusion 的稳定基础。这样可以消除因软件包版本不兼容而导致的潜在冲突。
JAX 稳定版堆栈包含一个完全发布且经过认证的 libtpu.so,它是驱动 Cloud TPU 程序编译、执行和 ICI 网络配置的核心库。libtpu 版本取代了 JAX 之前使用的每夜 build,并通过 HLO/StableHLO IR 中的 PJRT 级资格测试确保 XLA 计算在 Cloud TPU 上的功能一致。
如需使用 JAX 稳定版堆栈构建 MaxText 和 MaxDiffusion Docker 映像,请在运行 docker_build_dependency_image.sh
脚本时,将 MODE
变量设置为 stable_stack
,并将 BASEIMAGE
变量设置为要使用的基准映像。
以下示例将 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
指定为基础映像:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
如需查看可用的 JAX 稳定堆栈基础映像的列表,请参阅 Artifact Registry 中的 JAX 稳定堆栈映像。
使用 XPK 运行工作负载
如果您不使用 MaxText 设置的默认值或 MaxDiffusion 设置的默认值,请设置以下环境变量:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
构建模型脚本。在后续步骤中,此脚本将作为训练命令复制。
暂时不要执行模型脚本。
MaxText
MaxText 是一个高性能、高度可伸缩的开源 LLM,采用纯 Python 和 JAX 编写,可在 TPU 和 GPU 上进行训练和推理。 Google Cloud
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma 是 Google DeepMind 基于 Gemini 研究和技术开发的一系列开放权重 LLM。
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral 是 Mistral AI 开发的利用稀疏混合专家 (MoE) 架构的先进 AI 模型。
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama 是由 Meta 开发的一系列开放权重 LLM。
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash
MaxDiffusion
MaxDiffusion 是一系列用纯 Python 和 JAX 编写的参考实现,其中包含在 XLA 设备(包括 Cloud TPU 和 GPU)上运行的各种潜在 diffusion 模型。Stable Diffusion 是一种潜在的文本到图像模型,可根据任何文本输入生成逼真的图片。
您需要安装特定的 Git 分支才能运行 MaxDiffusion,如以下
git checkout
命令所示。git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
训练脚本:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
使用您在上一步中创建的脚本运行模型。您必须指定
--base-docker-image
标志才能使用 MaxText 基础图片,或者指定--docker-image
标志和要使用的图片。可选:您可以通过添加
--enable-debug-logs
标志来启用调试日志记录。如需了解详情,请参阅在 MaxText 上调试 JAX。可选:您可以创建 Vertex AI 实验,通过添加
--use-vertex-tensorboard
标志将数据上传到 Vertex AI TensorBoard。如需了解详情,请参阅使用 Vertex AI 监控 MaxText 上的 JAX。python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command=$YOUR-MODEL-SCRIPT
导出以下变量:
export CLUSTER_NAME=CLUSTER_NAME: The name of your XPK cluster. export ACCELERATOR_TYPEACCELERATOR_TYPE: The version and size of your TPU. For example, `v6e-256`. export NUM_SLICES=NUM_SLICES: The number of Cloud TPU slices. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: The model script to execute as a training command.
输出中包含用于跟踪工作负载的链接,类似于以下内容:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
打开链接,然后点击日志标签页以实时跟踪工作负载。
在 MaxText 上调试 JAX
使用补充 XPK 命令诊断集群或工作负载未运行的原因。
- XPK 工作负载列表
- XPK 检查器
- 在创建 XPK 工作负载时,使用
--enable-debug-logs
标志在工作负载日志中启用详细日志记录。
使用 Vertex AI 监控 MaxText 上的 JAX
通过 Vertex AI 的托管式 TensorBoard 查看标量和性能数据。
- 将您所用可用区的资源管理 (CRUD) 请求次数从 600 提高到 5,000。对于使用少于 16 个虚拟机的小型工作负载,这可能不是问题。
为 Vertex AI 安装
cloud-accelerator-diagnostics
等依赖项:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
使用
--create-vertex-tensorboard
标志创建 XPK 集群,如创建 Vertex AI TensorBoard 中所述。您也可以在现有集群上运行此命令。在运行 XPK 工作负载时,使用
--use-vertex-tensorboard
标志和可选的--experiment-name
标志创建 Vertex AI 实验。如需查看完整步骤列表,请参阅创建 Vertex AI 实验以将数据上传到 Vertex AI TensorBoard。
日志包含指向 Vertex AI TensorBoard 的链接,如下所示:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
您还可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 链接。前往 Google Cloud 控制台中的 Vertex AI Experiments。从下拉菜单中选择适当的地区。
TensorBoard 目录也会写入您使用 ${BASE_OUTPUT_DIR}
指定的 Cloud Storage 存储分区。
删除 XPK 工作负载
您可以使用 xpk workload delete
命令根据作业前缀或作业状态删除一个或多个工作负载。如果您发送的 XPK 工作负载不再需要运行,或者您有作业卡在队列中,此命令可能会很有用。
删除 XPK 集群
使用 xpk cluster delete
命令删除集群:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
在 v6e Cloud TPU 虚拟机上进行 Llama 和 PyTorch/XLA 训练
本教程介绍了如何使用 WikiText 数据集在 Cloud TPU v6e 上使用 PyTorch/XLA 训练 Llama 模型。
获取对 Hugging Face 和 Llama 3 模型的访问权限
您需要 Hugging Face 用户访问令牌才能运行本教程。如需了解如何创建和使用访问令牌,请参阅 Hugging Face 文档中的用户访问令牌部分。
您还需要有权访问 Hugging Face 上的 Llama 3 8B 模型。如需获取访问权限,请前往 HuggingFace 上的 Meta-Llama-3-8B 模型并请求访问权限。
创建 Cloud TPU 虚拟机
创建一个包含 8 个芯片的 Cloud TPU v6e 来运行本教程。
设置环境变量:
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT_ID=your-project-id export ZONE=your-zone
创建 Cloud TPU 虚拟机:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
安装
安装 Hugging Face Transformer 的 pytorch-tpu/transformers
分支及其依赖项。本教程是使用以下示例中使用的依赖项版本进行测试的:
torch
:与 2.5.0 兼容torch_xla[tpu]
:与 2.5.0 兼容jax
:0.4.33jaxlib
:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} --zone ${ZONE} \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
设置模型配置
下一部分(运行模型)中的训练命令使用两个 JSON 配置文件来定义模型参数和 FSDP(完全分片数据并行)配置。FSDP 分片用于模型权重,以便在训练过程中适应更大的批量大小。使用较小模型进行训练时,使用数据并行处理并在每台设备上复制权重可能就足够了。如需详细了解如何在 PyTorch/XLA 中跨设备分片张量,请参阅 PyTorch/XLA SPMD 用户指南。
创建模型参数配置文件。以下是 Llama3-8B 的模型参数配置。对于其他模型,请在 Hugging Face 上查找配置。例如,请参阅 Llama2-7B 配置。
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
创建 FSDP 配置文件:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
如需详细了解 FSDP,请参阅 FSDPv2。
使用以下命令将配置文件上传到 Cloud TPU 虚拟机:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
运行模型
使用您在上一部分中创建的配置文件,运行 run_clm.py
脚本,以便在 WikiText 数据集上训练 Llama 3 8B 模型。训练脚本在 Cloud TPU v6e-8 上大约需要 10 分钟才能运行完毕。
使用以下命令在 Cloud TPU 上登录 Hugging Face:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
运行模型训练:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
PyTorch/XLA 问题排查
如果您在上一部分中设置了用于调试的可选变量,则模型的配置文件将存储在变量 PROFILE_LOGDIR
指定的位置。您可以提取存储在此位置的 xplane.pb
文件,并使用 tensorboard
按照 TensorBoard 说明在浏览器中查看配置文件。如果 PyTorch/XLA 的运行情况不符合预期,请参阅问题排查指南,其中提供了有关调试、性能分析和优化模型的建议。
在 v6e 上进行 DLRM DCN v2 训练
本教程介绍如何在 Cloud TPU v6e 上训练 DLRM DCN v2 模型。您需要预配 TPU v6e,其中包含 64、128 或 256 个芯片。
如果您在多主机 TPU 上运行,请运行以下命令,使用适当的 TensorFlow 版本重置 tpu-runtime
。如果您在单主机 TPU 上运行,则无需运行以下两个命令。
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
使用 SSH 连接到 worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
设置 Cloud TPU 名称
export TPU_NAME=${TPU_NAME}
运行 DLRM v2
将以下代码段复制到名为 script.sh
的文件中:
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
如果您在 GKE 上运行 TensorFlow,请使用以下命令安装 TensorFlow Cloud TPU wheel 和 libtpu:
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
设置以下标志,这些标志对于运行推荐工作负载(例如 DLRM DCN)至关重要:
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
运行 script.sh
:
chmod +x script.sh
./script.sh
基准测试结果
以下部分包含在 v6e 上针对 DLRM DCN v2 和 MaxDiffusion 进行的基准测试结果。
DLRM DCN v2
DLRM DCN v2 训练脚本在不同规模下运行。请参阅下表中的吞吐量。
v6e-64 | v6e-128 | v6e-256 | |
---|---|---|---|
训练步骤 | 7000 | 7000 | 7000 |
全局批量大小 | 131072 | 262144 | 524288 |
吞吐量(示例/秒) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
我们在 v6e-4、v6e-16 和两个 v6e-16 上运行了 MaxDiffusion 的训练脚本。请参阅下表中的吞吐量。
v6e-4 | v6e-16 | 两个 v6e-16 | |
---|---|---|---|
训练步骤 | 0.069 | 0.073 | 0.13 |
全局批量大小 | 8 | 32 | 64 |
吞吐量(示例/秒) | 115.9 | 438.4 | 492.3 |