Trillium (v6e) 简介

在本文档、TPU API 和日志中,v6e 用于指代 Trillium。v6e 代表 Google 第 6 代 TPU。

v6e 架构每个 Pod 包含 256 个芯片,与 v5e 有许多相似之处。此系统针对转换器、文本到图像和卷积神经网络 (CNN) 训练、微调和服务进行了优化。

如需详细了解 v6e 系统架构和配置,请参阅 TPU v6e

本简介文档重点介绍了使用 JAXPyTorch TensorFlow 框架进行模型训练和服务的流程。对于每种框架,您都可以使用队列化资源或 GKE 预配 TPU。您可以使用 XPK 或 GKE 命令进行 GKE 设置。

使用 v6e 训练或部署模型的一般流程

  1. 准备 Google Cloud 项目
  2. 安全容量
  3. 预配 Cloud TPU 环境
  4. 运行模型训练推理工作负载

准备 Google Cloud 项目

在使用 Cloud TPU 之前,您需要:

  • 创建 Google Cloud 已启用结算功能的账号和项目
  • 安装 Google Cloud CLI Alpha 版组件
  • 启用 Cloud TPU API
  • 创建 Cloud TPU 服务代理
  • 创建 Cloud TPU 服务账号并授予权限

如需了解详情,请参阅设置 Cloud TPU 环境

保障容量

如需申请 Cloud TPU v6e 配额,或解答与容量有关的任何问题,请与 Google Cloud 支持团队联系。

预配 Cloud TPU 环境

v6e Cloud TPU 可以使用 GKE、GKE 和 XPK(一种基于 GKE 的封装容器 CLI 工具)进行预配和管理,也可以作为队列化资源进行管理。

前提条件

  • 验证您的项目是否有足够的 TPUS_PER_TPU_FAMILY 配额,该配额指定您可以在 Google Cloud项目中访问的芯片数量上限。
  • v6e 已通过以下配置进行测试:
    • Python 3.10 或更高版本
    • 每夜软件版本:
      • 每夜 JAX 0.4.32.dev20240912
      • 每夜 LibTPU 0.1.dev20240912+nightly
    • 稳定版软件版本:
      • JAX + v0.4.37 的 JAX 库
  • 请验证您的项目是否有足够的配额来执行以下操作:

    • Cloud TPU 虚拟机配额
    • IP 地址配额
    • Hyperdisk Balanced 配额

  • 如果您将 GKE 与 XPK 搭配使用,请参阅用户或服务账号的 Cloud 控制台权限,了解运行 XPK 所需的权限。

创建环境变量

在 Cloud Shell 中,创建以下环境变量:

export NODE_ID=your-tpu-name
export PROJECT_ID=your-project-id
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
export VALID_DURATION=your-duration 

# Additional environment variable needed for provisioning Multislice:
export NUM_SLICES=number-of-slices

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

命令标志说明

变量 说明
NODE_ID 用户分配给 Cloud TPU 的 ID,该 ID 在分配已排队的资源请求时创建。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目
区域 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。
ACCELERATOR_TYPE 请参阅加速器类型
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT 这是您的服务账号的电子邮件地址,您可以在 Google Cloud Console -> IAM -> 服务账号

例如:tpu-service-account@.iam.gserviceaccount.com.com

NUM_SLICES 要创建的 Slice 的数量(仅适用于多 Slice)。
QUEUED_RESOURCE_ID 已加入队列的资源请求的用户分配的文本 ID。
VALID_DURATION 队列中资源请求的有效时长。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

优化网络性能

为了获得最佳性能,请使用 8,896 MTU(最大传输单元)的网络。

默认情况下,虚拟私有云 (VPC) 仅提供 1,460 字节的 MTU,这会导致网络性能不佳。您可以将 VPC 网络的 MTU 设置为 1300 字节到 8896 字节之间(含边界值)的任何值。常见的自定义 MTU 大小为 1500 字节(标准以太网)或 8896 字节(可能的最大值)。如需了解详情,请参阅有效的 VPC 网络 MTU 大小

如需详细了解如何更改现有网络或默认网络的 MTU 设置,请参阅更改 VPC 网络的 MTU 设置

以下示例会创建一个 MTU 为 8,896 的网络。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
 --allow tcp,icmp,udp --project=${PROJECT_ID}

使用多 NIC(适用于多 Slice)

使用多 slice 环境时,辅助子网需要以下环境变量。

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

使用以下命令为网络和子网创建自定义 IP 路由。

gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 --region=${REGION} \
   --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

创建多网络 slice 后,您可以通过设置 XPK 集群并将 --command ifconfig 标志添加到 XPK 工作负载创建命令,验证是否使用了两个网络接口卡 (NIC)。

使用以下 xpk workload 命令在 Google Cloud 控制台日志中显示 ifconfig 命令的输出,并检查 eth0 和 eth1 是否均为 mtu=8896。

python3 xpk.py workload create \
   --cluster your-cluster-name \
   (--base-docker-image maxtext_base_image|--docker-image your-cloud-image-name \
   --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

验证 eth0 和 eth1 是否均为 mtu=8,896。您可以通过向 XPK 工作负载创建命令添加 --command ifconfig 标志来验证多 NIC 是否正在运行。在 Google Cloud 控制台日志中查看该 xpk 工作负载的输出,并验证 eth0 和 eth1 的 mtu 均为 8896。

改进 TCP 设置

如果您使用已排队的资源界面创建了 Cloud TPU,则可以运行以下命令,通过增加 TCP 接收缓冲区限制来提升网络性能。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT}" \
  --zone "${ZONE}" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

使用已排队的资源进行预配

您可以使用排队资源创建 Cloud TPU v6e。借助加入队列的资源,您可以在有容量可用时接收容量。您可以指定请求填充的开始时间和结束时间(可选)。如需了解详情,请参阅管理队列中的资源

使用 GKE 或 XPK 预配 v6e Cloud TPU

如果您将 GKE 命令与 v6e 搭配使用,则可以使用 Kubernetes 命令或 XPK 预配 Cloud TPU,以及训练或部署模型。如需了解如何在 GKE 集群中规划 Cloud TPU 配置,请参阅在 GKE 中规划 Cloud TPU。以下部分提供了用于创建支持单个 NIC 和多 NIC 的 XPK 集群的命令。

创建支持单个 NIC 的 XPK 集群

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
   gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
   gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
   python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments=${CLUSTER_ARGUMENTS}  \
   --create-vertex-tensorboard

命令标志说明

变量 说明
CLUSTER_NAME XPK 集群的用户分配的名称。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目
区域 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。
TPU_TYPE 请参阅加速器类型
NUM_SLICES 您要创建的 slice 的数量
CLUSTER_ARGUMENTS 要使用的网络和子网。

例如:--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 要创建的切片数量。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

创建支持多 NIC 的 XPK 集群

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
   gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
   gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
   gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
  gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_1} \
    --region=${REGION}
  gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
   gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
   gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
   gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
   gcloud compute routers create ${ROUTER_NAME} \
     --project=${PROJECT_ID} \
     --network=${NETWORK_NAME_2} \
     --region=${REGION}
   gcloud compute routers nats create ${NAT_CONFIG} \
     --router=${ROUTER_NAME} \
     --region=${REGION} \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project=${PROJECT_ID} \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster=${CLUSTER_NAME} \
--num-slices=${NUM_SLICES} \
--tpu-type=${TPU_TYPE} \
--zone=${ZONE}  \
--project=${PROJECT_ID} \
--on-demand \
--custom-cluster-arguments=${CLUSTER_ARGUMENTS} \
--custom-nodepool-arguments=${NODE_POOL_ARGUMENTS} \
--create-vertex-tensorboard

命令标志说明

变量 说明
CLUSTER_NAME XPK 集群的用户分配的名称。
PROJECT_ID Google Cloud 项目名称。使用现有项目或创建新项目。 如需了解详情,请参阅设置 Google Cloud 项目
区域 如需了解支持的区域,请参阅 Cloud TPU 区域和可用区文档。
TPU_TYPE 请参阅加速器类型
NUM_SLICES 您要创建的 slice 的数量
CLUSTER_ARGUMENTS 要使用的网络和子网。

例如:--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 要使用的额外节点网络。

例如:--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 要创建的 Slice 的数量(仅适用于多 Slice)。
NETWORK_NAME 要使用的辅助网络的名称。
NETWORK_FW_NAME 要使用的次要网络防火墙的名称。

框架设置

本部分介绍了使用 JAXPyTorchTensorFlow 框架进行机器学习模型训练的一般设置流程。如果您使用的是 GKE,则可以使用 XPK 或 Kubernetes 命令进行框架设置。

JAX 设置

本部分介绍了在 GKE 上运行 JAX 工作负载(无论是否使用 XPK)以及使用队列化资源的设置说明。

使用 GKE 设置 JAX

单个主机上的单个切片

以下示例使用 Kubernetes YAML 文件设置了 2x2 单主机节点池。

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

成功完成后,您应该会在 GKE 日志中看到以下消息:

Total TPU chips: 4

多主机上的单个切片

以下示例使用 Kubernetes YAML 文件设置了 4x4 多主机节点池。

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

成功完成后,您应该会在 GKE 日志中看到以下消息:

Total TPU chips: 16

多主机上的多切片

以下示例使用 Kubernetes YAML 文件设置了两个 4x4 多主机节点池。

前提是,您需要安装 JobSet v0.2.3 或更高版本。

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

成功完成后,您应该会在 GKE 日志中看到以下消息:

Total TPU chips: 32

如需了解详情,请参阅 GKE 文档中的运行多切片工作负载

为了获得更好的性能,请启用 hostNetwork

多 NIC

如需在 GKE 中充分利用多 NIC,Kubernetes Pod 清单需要添加其他注解。以下是非 TPU 多 NIC 工作负载示例清单。

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

如果您使用 exec 命令连接到 Kubernetes Pod,则应使用以下代码看到额外的 NIC。

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

使用 GKE 搭配 XPK 设置 JAX

如需使用 GKE 和 XPK 设置 JAX,请参阅 xpk README

如需使用 MaxText 设置和运行 XPK,请参阅如何运行 MaxText

使用已排队的资源设置 JAX

使用 gcloud alpha compute tpus tpu-vm ssh 命令同时在切片中的所有 Cloud TPU 虚拟机上安装 JAX。对于多 Slice,请添加 --node=all 标志。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests
 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

您可以运行以下命令,检查您的 slice 中可用的 Cloud TPU 核心数量,并测试是否已正确安装所有组件:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

在 v6e-16 slice 上运行时,输出类似于以下内容:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() 显示给定 slice 中的芯片总数。jax.local_device_count() 表示此 slice 中单个虚拟机可访问的芯片数量。

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

排查 JAX 设置问题

一般提示是在 GKE 工作负载清单中启用详细日志记录。然后,将日志提供给 GKE 支持团队。

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

错误消息

no endpoints available for service 'jobset-webhook-service'

此错误表示作业集未正确安装。检查 jobset-controller-manager 部署 Kubernetes Pod 是否正在运行。如需了解详情,请参阅 JobSet 问题排查文档

TPU initialization failed: Failed to connect

确保您的 GKE 节点版本为 1.30.4-gke.1348000 或更高版本(不支持 GKE 1.31)。

PyTorch 设置

本部分介绍了如何开始在 v6e 上使用 PyTorch/XLA 的 PJRT。建议使用 Python 3.10。

使用 XPK 通过 GKE 设置 PyTorch

您可以将以下 Docker 容器与已安装 PyTorch 依赖项的 XPK 搭配使用:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

如需创建 XPK 工作负载,请使用以下命令:

python3 xpk.py workload create \
    --cluster ${CLUSTER_NAME} \
    [--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
    --workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
    --tpu-type=${ACCELERATOR_TYPE} \
    --num-slices=${NUM_SLICES}  \
    --on-demand \
    --zone ${ZONE} \
    --project ${PROJECT_ID} \
    --enable-debug-logs \
    --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

使用 --base-docker-image 会创建一个新的 Docker 映像,并将当前工作目录内置到新的 Docker 中。

使用已排队的资源设置 PyTorch

请按照以下步骤使用队列化资源安装 PyTorch,并在 v6e 上运行一个小脚本。

使用 SSH 安装依赖项以访问虚拟机

使用以下命令在所有 Cloud TPU 虚拟机上安装依赖项。对于多 Slice,请添加 --node=all 标志:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

提高具有可扩缩、频繁分配的模型的性能

对于具有可扩缩、频繁分配的模型,与使用默认的 malloc 函数实现相比,使用 tcmalloc 函数可以显著提升性能,因此 Cloud TPU VM 上默认使用的 malloc 函数是 tcmalloc。但是,根据您的工作负载(例如为其嵌入表进行了超大规模分配的 DLRM),tcmalloc 函数可能会造成运行缓慢,在这种情况下,您可以尝试设置以下变量以改为使用默认 malloc 函数:

unset LD_PRELOAD

使用 Python 脚本对 v6e 虚拟机执行计算

使用以下命令运行一个脚本,该脚本会创建两个张量,将它们相加,然后输出结果。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

这将生成如下所示的输出:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

TensorFlow 设置

您可以通过运行以下命令,使用与 v6e 兼容的 TensorFlow 版本重置 Cloud TPU 运行时:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

使用 SSH 访问 worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

在 worker-0 上安装 TensorFlow:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

导出 TPU_NAME 环境变量:

export TPU_NAME=v6e-16

您可以运行以下 Python 脚本,检查您的 slice 中可用的 Cloud TPU 核心数量,并测试所有内容是否已正确安装:

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

在 v6e-16 slice 上运行时,输出类似于以下内容:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

配备 SkyPilot 的 v6e

您可以将 Cloud TPU v6e 与 SkyPilot 搭配使用。请按照以下步骤向 SkyPilot 添加与 v6e 相关的位置和价格信息。

  1. 将以下代码添加到 ~/.sky/catalogs/v5/gcp/vms.csv 文件的末尾:

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. 在 YAML 文件中指定以下资源:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. 使用 Cloud TPU v6e 启动集群:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. 使用 SSH 连接到 Cloud TPU v6e:ssh tpu_v6

推理教程

以下教程介绍了如何在 Cloud TPU v6e 上运行推理:

训练示例

以下部分提供了在 Cloud TPU v6e 上训练 MaxText、MaxDiffusion 和 PyTorch 模型的示例。

在 v6e Cloud TPU 虚拟机上进行 MaxText 和 MaxDiffusion 训练

以下部分介绍了 MaxTextMaxDiffusion 模型的训练生命周期。

一般而言,大致步骤如下:

  1. 构建工作负载基础映像。
  2. 使用 XPK 运行工作负载。
    1. 为工作负载构建训练命令。
    2. 部署工作负载。
  3. 跟踪工作负载并查看指标。
  4. 如果不需要 XPK 工作负载,请将其删除。
  5. 不再需要 XPK 集群时,请将其删除。

构建基础映像

安装 MaxText 或 MaxDiffusion 并构建 Docker 映像:

  1. 克隆要使用的代码库,然后切换到该代码库的目录:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. 将 Docker 配置为使用 Google Cloud CLI:

    gcloud auth configure-docker
    
  3. 使用以下命令或 JAX 稳定版堆栈构建 Docker 映像。如需详细了解 JAX Stable Stack,请参阅使用 JAX Stable Stack 构建 Docker 映像

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
    
  4. 如果您要从未在本地构建映像的机器启动工作负载,请上传映像:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
使用 JAX 稳定版堆栈构建 Docker 映像

您可以使用 JAX Stable Stack 基本映像构建 MaxText 和 MaxDiffusion Docker 映像。

JAX 稳定堆栈通过将 JAX 与 orbaxflaxoptax 等核心软件包以及经过充分限定的 libtpu.so 捆绑在一起,为 MaxText 和 MaxDiffusion 提供了一致的环境,以驱动 Cloud TPU 程序实用程序和其他基本工具。这些库经过测试,以确保兼容性,并提供构建和运行 MaxText 和 MaxDiffusion 的稳定基础。这样可以消除因软件包版本不兼容而导致的潜在冲突。

JAX 稳定版堆栈包含一个完全发布且经过认证的 libtpu.so,它是驱动 Cloud TPU 程序编译、执行和 ICI 网络配置的核心库。libtpu 版本取代了 JAX 之前使用的每夜 build,并通过 HLO/StableHLO IR 中的 PJRT 级资格测试确保 XLA 计算在 Cloud TPU 上的功能一致。

如需使用 JAX 稳定版堆栈构建 MaxText 和 MaxDiffusion Docker 映像,请在运行 docker_build_dependency_image.sh 脚本时,将 MODE 变量设置为 stable_stack,并将 BASEIMAGE 变量设置为要使用的基准映像。

以下示例将 us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1 指定为基础映像:

bash docker_build_dependency_image.sh MODE=stable_stack
BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1

如需查看可用的 JAX 稳定堆栈基础映像的列表,请参阅 Artifact Registry 中的 JAX 稳定堆栈映像

使用 XPK 运行工作负载

  1. 如果您不使用 MaxText 设置的默认值MaxDiffusion 设置的默认值,请设置以下环境变量:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. 构建模型脚本。在后续步骤中,此脚本将作为训练命令复制。

    暂时不要执行模型脚本。

    MaxText

    MaxText 是一个高性能、高度可伸缩的开源 LLM,采用纯 Python 和 JAX 编写,可在 TPU 和 GPU 上进行训练和推理。 Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma 是 Google DeepMind 基于 Gemini 研究和技术开发的一系列开放权重 LLM。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral 是 Mistral AI 开发的利用稀疏混合专家 (MoE) 架构的先进 AI 模型。

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama 是由 Meta 开发的一系列开放权重 LLM。

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash
    

    MaxDiffusion

    MaxDiffusion 是一系列用纯 Python 和 JAX 编写的参考实现,其中包含在 XLA 设备(包括 Cloud TPU 和 GPU)上运行的各种潜在 diffusion 模型。Stable Diffusion 是一种潜在的文本到图像模型,可根据任何文本输入生成逼真的图片。

    您需要安装特定的 Git 分支才能运行 MaxDiffusion,如以下 git checkout 命令所示。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    训练脚本:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
        
  3. 使用您在上一步中创建的脚本运行模型。您必须指定 --base-docker-image 标志才能使用 MaxText 基础图片,或者指定 --docker-image 标志和要使用的图片。

    可选:您可以通过添加 --enable-debug-logs 标志来启用调试日志记录。如需了解详情,请参阅在 MaxText 上调试 JAX

    可选:您可以创建 Vertex AI 实验,通过添加 --use-vertex-tensorboard 标志将数据上传到 Vertex AI TensorBoard。如需了解详情,请参阅使用 Vertex AI 监控 MaxText 上的 JAX

    python3 xpk.py workload create \
        --cluster ${CLUSTER_NAME} \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command=$YOUR-MODEL-SCRIPT

    导出以下变量:

    export CLUSTER_NAME=CLUSTER_NAME: The name of your XPK cluster.
    export ACCELERATOR_TYPEACCELERATOR_TYPE: The version and size of your TPU. For example, `v6e-256`.
    export NUM_SLICES=NUM_SLICES: The number of Cloud TPU slices.
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: The model script to execute as a training command.

    输出中包含用于跟踪工作负载的链接,类似于以下内容:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    打开链接,然后点击日志标签页以实时跟踪工作负载。

在 MaxText 上调试 JAX

使用补充 XPK 命令诊断集群或工作负载未运行的原因。

使用 Vertex AI 监控 MaxText 上的 JAX

通过 Vertex AI 的托管式 TensorBoard 查看标量和性能数据。

  1. 将您所用可用区的资源管理 (CRUD) 请求次数从 600 提高到 5,000。对于使用少于 16 个虚拟机的小型工作负载,这可能不是问题。
  2. 为 Vertex AI 安装 cloud-accelerator-diagnostics 等依赖项:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. 使用 --create-vertex-tensorboard 标志创建 XPK 集群,如创建 Vertex AI TensorBoard 中所述。您也可以在现有集群上运行此命令。

  4. 在运行 XPK 工作负载时,使用 --use-vertex-tensorboard 标志和可选的 --experiment-name 标志创建 Vertex AI 实验。如需查看完整步骤列表,请参阅创建 Vertex AI 实验以将数据上传到 Vertex AI TensorBoard

日志包含指向 Vertex AI TensorBoard 的链接,如下所示:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

您还可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 链接。前往 Google Cloud 控制台中的 Vertex AI Experiments。从下拉菜单中选择适当的地区。

TensorBoard 目录也会写入您使用 ${BASE_OUTPUT_DIR} 指定的 Cloud Storage 存储分区。

删除 XPK 工作负载

您可以使用 xpk workload delete 命令根据作业前缀或作业状态删除一个或多个工作负载。如果您发送的 XPK 工作负载不再需要运行,或者您有作业卡在队列中,此命令可能会很有用。

删除 XPK 集群

使用 xpk cluster delete 命令删除集群:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone=${ZONE} --project=${PROJECT_ID}

在 v6e Cloud TPU 虚拟机上进行 Llama 和 PyTorch/XLA 训练

本教程介绍了如何使用 WikiText 数据集在 Cloud TPU v6e 上使用 PyTorch/XLA 训练 Llama 模型。

获取对 Hugging Face 和 Llama 3 模型的访问权限

您需要 Hugging Face 用户访问令牌才能运行本教程。如需了解如何创建和使用访问令牌,请参阅 Hugging Face 文档中的用户访问令牌部分

您还需要有权访问 Hugging Face 上的 Llama 3 8B 模型。如需获取访问权限,请前往 HuggingFace 上的 Meta-Llama-3-8B 模型并请求访问权限。

创建 Cloud TPU 虚拟机

创建一个包含 8 个芯片的 Cloud TPU v6e 来运行本教程。

  1. 设置环境变量:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT_ID=your-project-id
    export ZONE=your-zone
  2. 创建 Cloud TPU 虚拟机:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --zone=${ZONE} \
        --project=${PROJECT_ID}

安装

安装 Hugging Face Transformer 的 pytorch-tpu/transformers 分支及其依赖项。本教程是使用以下示例中使用的依赖项版本进行测试的:

  • torch:与 2.5.0 兼容
  • torch_xla[tpu]:与 2.5.0 兼容
  • jax:0.4.33
  • jaxlib:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} --zone ${ZONE} \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

设置模型配置

下一部分(运行模型)中的训练命令使用两个 JSON 配置文件来定义模型参数和 FSDP(完全分片数据并行)配置。FSDP 分片用于模型权重,以便在训练过程中适应更大的批量大小。使用较小模型进行训练时,使用数据并行处理并在每台设备上复制权重可能就足够了。如需详细了解如何在 PyTorch/XLA 中跨设备分片张量,请参阅 PyTorch/XLA SPMD 用户指南

  1. 创建模型参数配置文件。以下是 Llama3-8B 的模型参数配置。对于其他模型,请在 Hugging Face 上查找配置。例如,请参阅 Llama2-7B 配置

    cat > llama-config.json << EOF
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
  2. 创建 FSDP 配置文件:

    cat > fsdp-config.json << EOF
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF

    如需详细了解 FSDP,请参阅 FSDPv2

  3. 使用以下命令将配置文件上传到 Cloud TPU 虚拟机:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
        --worker=all \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

运行模型

使用您在上一部分中创建的配置文件,运行 run_clm.py 脚本,以便在 WikiText 数据集上训练 Llama 3 8B 模型。训练脚本在 Cloud TPU v6e-8 上大约需要 10 分钟才能运行完毕。

  1. 使用以下命令在 Cloud TPU 上登录 Hugging Face:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. 运行模型训练:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \
        --zone ${ZONE} \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

PyTorch/XLA 问题排查

如果您在上一部分中设置了用于调试的可选变量,则模型的配置文件将存储在变量 PROFILE_LOGDIR 指定的位置。您可以提取存储在此位置的 xplane.pb 文件,并使用 tensorboard 按照 TensorBoard 说明在浏览器中查看配置文件。如果 PyTorch/XLA 的运行情况不符合预期,请参阅问题排查指南,其中提供了有关调试、性能分析和优化模型的建议。

在 v6e 上进行 DLRM DCN v2 训练

本教程介绍如何在 Cloud TPU v6e 上训练 DLRM DCN v2 模型。您需要预配 TPU v6e,其中包含 64、128 或 256 个芯片。

如果您在多主机 TPU 上运行,请运行以下命令,使用适当的 TensorFlow 版本重置 tpu-runtime。如果您在单主机 TPU 上运行,则无需运行以下两个命令。

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

使用 SSH 连接到 worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

设置 Cloud TPU 名称

export TPU_NAME=${TPU_NAME}

运行 DLRM v2

将以下代码段复制到名为 script.sh 的文件中:

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

如果您在 GKE 上运行 TensorFlow,请使用以下命令安装 TensorFlow Cloud TPU wheel 和 libtpu:

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

设置以下标志,这些标志对于运行推荐工作负载(例如 DLRM DCN)至关重要:

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

运行 script.sh

chmod +x script.sh
./script.sh

基准测试结果

以下部分包含在 v6e 上针对 DLRM DCN v2 和 MaxDiffusion 进行的基准测试结果。

DLRM DCN v2

DLRM DCN v2 训练脚本在不同规模下运行。请参阅下表中的吞吐量。

v6e-64 v6e-128 v6e-256
训练步骤 7000 7000 7000
全局批量大小 131072 262144 524288
吞吐量(示例/秒) 2975334 5111808 10066329

MaxDiffusion

我们在 v6e-4、v6e-16 和两个 v6e-16 上运行了 MaxDiffusion 的训练脚本。请参阅下表中的吞吐量。

v6e-4 v6e-16 两个 v6e-16
训练步骤 0.069 0.073 0.13
全局批量大小 8 32 64
吞吐量(示例/秒) 115.9 438.4 492.3