Trillium (v6e) 簡介
在本說明文件、TPU API 和記錄中,v6e 是指 Trillium。v6e 代表 Google 第 6 代 TPU。
每個 Pod 都有 256 個晶片,因此 v6e 架構與 v5e 有許多相似之處。這個系統經過最佳化,適用於 Transformer、文字轉圖片和卷積類神經網路 (CNN) 的訓練、微調和服務。
如要進一步瞭解 v6e 系統架構和設定,請參閱「TPU v6e」。
這份簡介文件著重於使用 JAX 或 PyTorch 架構訓練及提供模型服務的程序。使用各個架構時,您可以透過佇列資源或 GKE 佈建 TPU。您可以使用 XPK 或 GKE 指令設定 GKE。
使用 v6e 訓練或提供模型的一般程序
- 準備 Google Cloud 專案
- 確保容量
- 佈建 Cloud TPU 環境
- 執行模型訓練或推論工作負載
準備專案 Google Cloud
使用 Cloud TPU 前,請先完成下列步驟:
- 建立 Google Cloud 帳戶和專案並啟用計費功能
- 安裝 Google Cloud CLI Alpha 版元件
- 啟用 Cloud TPU API
- 建立 Cloud TPU 服務代理
- 建立 Cloud TPU 服務帳戶並授予權限
詳情請參閱「設定 Cloud TPU 環境」。
安全容量
如要申請 Cloud TPU v6e 配額,或對容量有任何疑問,請與Google Cloud 支援團隊聯絡。
佈建 Cloud TPU 環境
您可以使用 GKE、GKE 和 XPK (GKE 的包裝函式 CLI 工具),或以佇列資源的形式,佈建及管理 v6e Cloud TPU。
必要條件
- 確認專案有足夠的
TPUS_PER_TPU_FAMILY
配額,這項配額會指定您可在 Google Cloud專案中存取的晶片數量上限。 - v6e 已使用下列設定進行測試:
- Python
3.10
以上版本 - 夜間軟體版本:
- 每晚 JAX
0.4.32.dev20240912
- 每晚 LibTPU
0.1.dev20240912+nightly
- 每晚 JAX
- 穩定版軟體:
- JAX + JAX Lib 0.4.37 版
- Python
確認專案有足夠的配額,可供下列項目使用:
- Cloud TPU VM 配額
- IP 位址配額
Hyperdisk Balanced 的配額,以及您想使用的任何其他磁碟類型
如果您使用 GKE 搭配 XPK,請參閱「使用者或服務帳戶的 Cloud Console 權限」,瞭解執行 XPK 時所需的權限。
建立環境變數
在 Cloud Shell 中建立下列環境變數:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration # Additional environment variable needed for Multislice: export NUM_SLICES=number-of-slices # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
指令旗標說明
變數 | 說明 |
NODE_ID | 排入佇列的資源要求獲得分配時,系統會建立 Cloud TPU,並指派使用者 ID。 |
PROJECT_ID | Google Cloud 專案名稱。使用現有專案或建立新專案。 詳情請參閱「設定 Google Cloud 專案」。 |
ZONE | 如要瞭解支援的區域,請參閱「Cloud TPU 地區和區域」文件。 |
ACCELERATOR_TYPE | 請參閱「加速器類型」。 |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | 這是服務帳戶的電子郵件地址,您可以在 Google Cloud 控制台 ->「IAM」>「服務帳戶」中找到
例如: |
NUM_SLICES | 要建立的切片數量 (僅限多切片)。 |
QUEUED_RESOURCE_ID | 使用者指派的佇列資源要求文字 ID。 |
VALID_DURATION | 排隊資源要求的有效期間。 |
NETWORK_NAME | 要使用的次要網路名稱。 |
NETWORK_FW_NAME | 要使用的次要網路防火牆名稱。 |
提升網路效能
如要獲得最佳效能,請使用 MTU (最大傳輸單位) 為 8,896 的網路。
根據預設,虛擬私有雲 (VPC) 只會提供 1,460 位元組的 MTU,這會導致網路效能不佳。您可以將虛擬私有雲網路的 MTU 設為 1,300 到 8,896 位元組 (含) 之間的任何值。常見的自訂 MTU 大小為 1,500 個位元組 (標準乙太網路) 或 8,896 個位元組 (最大可能值)。詳情請參閱「有效的虛擬私有雲網路 MTU 大小」。
如要進一步瞭解如何變更現有或預設網路的 MTU 設定,請參閱「變更虛擬私有雲網路的 MTU 設定」。
下列範例會建立 MTU 為 8,896 的網路。
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
使用多 NIC (多配量選項)
使用 Multislice 環境時,次要子網路需要下列環境變數。
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
使用下列指令,為網路和子網路建立自訂 IP 路由。
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
建立多重網路切片後,您可以設定 XPK 叢集,並在 XPK 工作負載建立指令中新增 --command ifconfig
旗標,驗證是否同時使用兩個網路介面卡 (NIC)。
使用下列 workload create
指令,在 Google Cloud 控制台記錄中顯示 ifconfig
指令的輸出內容,並確認 eth0 和 eth1 的 mtu=8896。
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
如要啟用偵錯記錄或使用 Vertex AI TensorBoard,請在指令中加入下列選用引數:
--enable-debug-logs \ --use-vertex-tensorboard
確認 eth0 和 eth1 的 mtu 都是 8,896。如要確認多重 NIC 是否正在執行,請將 --command ifconfig
旗標新增至 XPK 工作負載建立指令。在 Google Cloud 控制台記錄
中檢查該 XPK 工作負載的輸出內容,並確認 eth0 和 eth1 的 mtu 都是 8,896。
改善 TCP 設定
如果您是使用佇列資源介面建立 Cloud TPU,可以執行下列指令來提高 TCP 接收緩衝區限制,藉此提升網路效能。
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
使用排入佇列的資源佈建
您可以使用排入佇列的資源建立 Cloud TPU v6e。排隊等候資源可讓您在容量可用時取得容量。您可以視需要指定要求應填寫的開始和結束時間。詳情請參閱「管理已加入佇列的資源」。
透過 GKE 或 XPK 佈建 v6e Cloud TPU
如果您使用 v6e 搭配 GKE 指令,可以透過 Kubernetes 指令或 XPK 佈建 Cloud TPU,並訓練或提供模型。如要瞭解如何在 GKE 叢集中規劃 Cloud TPU 設定,請參閱「規劃 GKE 中的 Cloud TPU」。下列各節提供指令,可建立支援單一 NIC 和多個 NIC 的 XPK 叢集。
建立支援單一 NIC 的 XPK 叢集
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
指令旗標說明
變數 | 說明 |
CLUSTER_NAME | 使用者指派的 XPK 叢集名稱。 |
PROJECT_ID | Google Cloud 專案名稱。使用現有專案或建立新專案。 詳情請參閱「設定 Google Cloud 專案」。 |
ZONE | 如要瞭解支援的區域,請參閱「Cloud TPU 地區和區域」文件。 |
TPU_TYPE | 請參閱「加速器類型」。 |
NUM_SLICES | 要建立的切片數量 |
CLUSTER_ARGUMENTS | 要使用的網路和子網路。
例如: |
NUM_SLICES | 要建立的切片數量。 |
NETWORK_NAME | 要使用的次要網路名稱。 |
NETWORK_FW_NAME | 要使用的次要網路防火牆名稱。 |
建立支援多個 NIC 的 XPK 叢集
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
指令旗標說明
變數 | 說明 |
CLUSTER_NAME | 使用者指派的 XPK 叢集名稱。 |
PROJECT_ID | Google Cloud 專案名稱。使用現有專案或建立新專案。 詳情請參閱「設定 Google Cloud 專案」。 |
ZONE | 如要瞭解支援的區域,請參閱「Cloud TPU 地區和區域」文件。 |
TPU_TYPE | 請參閱「加速器類型」。 |
NUM_SLICES | 要建立的切片數量 |
CLUSTER_ARGUMENTS | 要使用的網路和子網路。
例如: |
NODE_POOL_ARGUMENTS | 要使用的額外節點網路。
例如: |
NUM_SLICES | 要建立的切片數量 (僅限多切片)。 |
NETWORK_NAME | 要使用的次要網路名稱。 |
NETWORK_FW_NAME | 要使用的次要網路防火牆名稱。 |
設定架構
本節說明使用 JAX 和 PyTorch 架構訓練機器學習模型的一般設定程序。如果您使用 GKE,可以透過 XPK 或 Kubernetes 指令設定架構。
設定 JAX
本節提供設定說明,介紹如何在 GKE 上執行 JAX 工作負載 (無論是否使用 XPK),以及如何使用已加入佇列的資源。
使用 GKE 設定 JAX
單一主機上的單一切片
以下範例使用 Kubernetes YAML 檔案,設定 2x2 單一主機節點集區。
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
成功完成後,您應該會在 GKE 記錄檔中看到以下訊息:
Total TPU chips: 4
多主機上的單一切片
以下範例使用 Kubernetes YAML 檔案,設定 4x4 多主機節點集區。
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
成功完成後,您應該會在 GKE 記錄檔中看到以下訊息:
Total TPU chips: 16
多主機上的多配量
以下範例使用 Kubernetes YAML 檔案,設定兩個 4x4 多主機節點集區。
您必須先安裝 JobSet 0.2.3 以上版本。
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
成功完成後,您應該會在 GKE 記錄檔中看到以下訊息:
Total TPU chips: 32
詳情請參閱 GKE 說明文件中的「執行多切片工作負載」。
如要提升效能,請啟用 hostNetwork。
多個 NIC
如要使用下列多重 NIC 資訊清單,您必須設定網路。詳情請參閱「為 Kubernetes Pod 設定多網路支援功能」。
如要在 GKE 中使用多重 NIC,您必須在 Kubernetes Pod 資訊清單中加入一些額外的註解。
以下是無 TPU 的多 NIC 工作負載範例資訊清單。
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
如果您使用 exec
指令連線至 Kubernetes Pod,應該會看到使用下列程式碼的額外 NIC:
$ kubectl exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
使用 XPK 透過 GKE 設定 JAX
如要使用 GKE 和 XPK 設定 JAX,請參閱 XPK README。
如要使用 MaxText 設定及執行 XPK,請參閱「如何執行 MaxText」。
使用排入佇列的資源設定 JAX
使用 gcloud alpha compute tpus tpu-vm ssh
指令,在配量或多個配量的所有 Cloud TPU VM 上同時安裝 JAX。如果是 Multislice,請新增 --node=all
旗標。
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
pip install -U --pre jax jaxlib libtpu-nightly requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
您可以執行下列指令,檢查配量中可用的 Cloud TPU 核心數量,並測試所有項目是否已正確安裝:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='
python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
在 v6e-16 切片上執行時,輸出內容會與下列內容類似:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
會顯示指定切片中的晶片總數。
jax.local_device_count()
表示這個切片中單一 VM 可存取的晶片數量。
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 &&
pip install setuptools==59.6.0 &&
pip install -r requirements.txt && pip install .'
排解 JAX 設定問題
一般來說,建議您在 GKE 工作負載資訊清單中啟用詳細記錄功能。然後將記錄檔提供給 GKE 支援團隊。
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
錯誤訊息
no endpoints available for service 'jobset-webhook-service'
這表示 Jobset 安裝有誤。檢查 jobset-controller-manager 部署作業的 Kubernetes Pod 是否正在執行。詳情請參閱 JobSet 疑難排解說明文件。
TPU initialization failed: Failed to connect
請確認 GKE 節點版本為 1.30.4-gke.1348000 以上版本 (不支援 GKE 1.31)。
PyTorch 設定
本節說明如何開始在 v6e 上使用 PyTorch/XLA 的 PJRT。建議使用 Python 3.10。
使用 GKE 和 XPK 設定 PyTorch
您可以使用下列 Docker 容器搭配 XPK,其中已安裝 PyTorch 依附元件:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
如要建立 XPK 工作負載,請使用下列指令:
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone ${ZONE} \ --project ${PROJECT_ID} \ --enable-debug-logs \ --command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
使用 --base-docker-image
會建立新的 Docker 映像檔,並將目前的工作目錄建構到新的 Docker 中。
使用排入佇列的資源設定 PyTorch
請按照下列步驟,使用已加入佇列的資源安裝 PyTorch,並在 v6e 上執行小型指令碼。
使用 SSH 存取 VM 並安裝依附元件
使用下列指令在所有 Cloud TPU VM 上安裝依附元件。如為 Multislice,請新增 --worker=all
標記:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt update && sudo apt install -y python3-pip libopenblas-base && \
pip3 install torch~=2.6.0 "torch_xla[tpu]~=2.6.0" -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
改善模型效能,並進行大規模的頻繁分配
對於有大量頻繁分配的模型,與預設的 malloc
函式實作相比,使用 tcmalloc
函式可大幅提升效能,因此 Cloud TPU VM 上使用的預設 malloc
函式為 tcmalloc
。不過,視工作負載而定 (例如,DLRM 的嵌入資料表分配量非常大),tcmalloc
函式可能會導致速度變慢,在這種情況下,您可以嘗試改用預設的 malloc
函式取消設定下列變數:
unset LD_PRELOAD
使用 Python 指令碼在 v6e VM 上執行計算
使用下列指令執行指令碼,建立兩個張量、將兩者相加,然後列印結果:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--worker all \
--command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
這麼做會產生類似以下內容的輸出結果:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
搭配 SkyPilot 的 v6e
您可以搭配 SkyPilot 使用 Cloud TPU v6e。請按照下列步驟,將與 v6e 相關的位置和價格資訊新增至 SkyPilot。詳情請參閱 SkyPilot TPU v6e 範例。
推論教學課程
下列教學課程說明如何在 Cloud TPU v6e 上執行推論:
訓練範例
下列各節提供在 Cloud TPU v6e 上訓練 MaxText、MaxDiffusion 和 PyTorch 模型的範例。
在 v6e Cloud TPU VM 上訓練 MaxText 和 MaxDiffusion
以下各節將說明 MaxText 和 MaxDiffusion 模型的訓練生命週期。
一般來說,高階步驟如下:
- 建構工作負載基本映像檔。
- 使用 XPK 執行工作負載。
- 為工作負載建構訓練指令。
- 部署工作負載。
- 追蹤工作負載並查看指標。
- 如不需要,請刪除 XPK 工作負載。
- 不再需要 XPK 叢集時,請將其刪除。
建構基本映像檔
安裝 MaxText 或 MaxDiffusion,並建構 Docker 映像檔:
複製要使用的存放區,然後變更為存放區的目錄:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
將 Docker 設為使用 Google Cloud CLI:
gcloud auth configure-docker
使用下列指令或 JAX Stable Stack 建構 Docker 映像檔。 如要進一步瞭解 JAX Stable Stack,請參閱「使用 JAX Stable Stack 建構 Docker 映像檔」。
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
在有效的 gcloud CLI 設定中設定專案 ID:
gcloud config set project ${PROJECT_ID}
如果從沒有在本機建構映像檔的機器啟動工作負載,請上傳映像檔。
設定
CLOUD_IMAGE_NAME
環境變數:export CLOUD_IMAGE_NAME=${USER}_runner
上傳圖片:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
使用 XPK 執行工作負載
如果未使用 MaxText 設定的預設值或 MaxDiffusion,請設定下列環境變數:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
建構模型指令碼。這個指令碼會在後續步驟中複製為訓練指令。
請先不要執行模型指令碼。
MaxText
MaxText 是以純 Python 和 JAX 編寫的開放原始碼 LLM,具備高效能和高擴充性,適用於 TPU 和 GPU,可進行訓練和推論。 Google Cloud
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma 是一系列開放權重的 LLM,由 Google DeepMind 根據 Gemini 研究和技術開發而成。
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral 是由 Mistral AI 開發的頂尖 AI 模型,採用稀疏專家混合 (MoE) 架構。
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama 是 Meta 開發的開放權重 LLM 系列。
如需瞭解如何在 PyTorch 上執行 Llama3,請參閱 torchprime 存放區中的 torch_xla 模型。
MaxDiffusion
MaxDiffusion 是一系列以純 Python 和 JAX 編寫的各種延遲擴散模型參考實作,可在 XLA 裝置上執行,包括 Cloud TPU 和 GPU。Stable Diffusion 是一種潛在文字轉圖像模型,可根據任何文字輸入生成逼真的圖像。
您需要安裝特定 Git 分支,才能執行 MaxDiffusion,如下列
git clone
指令所示。訓練指令碼:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml revision=refs/pr/95 activations_dtype=bfloat16 weights_dtype=bfloat16 resolution=1024 per_device_batch_size=1 output_dir=${OUT_DIR} jax_cache_dir=${OUT_DIR}/cache_dir/ max_train_steps=200 attention=flash run_name=sdxl-ddp-v6e
匯出下列變數:
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
環境變數說明
變數 說明 CLUSTER_NAME
XPK 叢集的名稱。 ACCELERATOR_TYPE
請參閱「加速器類型」。 NUM_SLICES
TPU 配量數量。 YOUR_MODEL_SCRIPT
要以訓練指令執行的模型指令碼。 使用上一步建立的指令碼執行模型。 您必須指定
--base-docker-image
旗標來使用 MaxText 基礎映像檔,或指定--docker-image
旗標和要使用的映像檔。選用:您可以加入
--enable-debug-logs
標記來啟用偵錯記錄功能。詳情請參閱「在 MaxText 上偵錯 JAX」。選用:您可以建立 Vertex AI 實驗,並加入
--use-vertex-tensorboard
標記,將資料上傳至 Vertex AI TensorBoard。詳情請參閱「使用 Vertex AI 監控 MaxText 上的 JAX」。python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command="${YOUR_MODEL_SCRIPT}"
輸出內容包含追蹤工作負載的連結。 開啟連結並點選「記錄」分頁,即可即時追蹤工作負載。
在 MaxText 上偵錯 JAX
使用補充 XPK 指令,診斷叢集或工作負載無法執行的原因:
- XPK 工作負載清單
- XPK 檢查工具
- 建立 XPK 工作負載時,使用
--enable-debug-logs
旗標在工作負載記錄中啟用詳細記錄
使用 Vertex AI 監控 MaxText 上的 JAX
如要使用 TensorBoard,您的 Google Cloud 使用者帳戶必須具備aiplatform.user
角色。執行下列指令來授予這個角色:
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
透過 Vertex AI 管理的 TensorBoard 查看純量和剖析資料。
將您使用的區域資源管理 (CRUD) 要求數從 600 提高至 5000。如果工作負載較小,使用的 VM 少於 16 個,可能不會有問題。
安裝 Vertex AI 的依附元件,例如
cloud-accelerator-diagnostics
:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
使用
--create-vertex-tensorboard
旗標建立 XPK 叢集,如「建立 Vertex AI TensorBoard」一文所述。您也可以在現有叢集上執行這項指令。使用
--use-vertex-tensorboard
旗標和選用的--experiment-name
旗標執行 XPK 工作負載時,請建立 Vertex AI 實驗。如需完整步驟清單,請參閱「建立 Vertex AI 實驗,將資料上傳至 Vertex AI TensorBoard」。
記錄會包含 Vertex AI TensorBoard 的連結,類似於下列連結:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
您也可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 連結。 前往 Google Cloud 控制台中的 Vertex AI Experiments。從下拉式選單中選取適當的地區。
TensorBoard 目錄也會寫入您以 ${BASE_OUTPUT_DIR}
指定的 Cloud Storage bucket。
刪除 XPK 工作負載
使用 xpk workload delete
指令,根據工作前置字元或工作狀態刪除一或多個工作負載。如果您傳送的 XPK 工作負載不再需要執行,或是工作停滯在佇列中,這個指令就非常實用。
刪除 XPK 叢集
使用 xpk cluster delete
指令刪除叢集:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
在 v6e Cloud TPU VM 上訓練 Llama 和 PyTorch/XLA
本教學課程說明如何使用 WikiText 資料集,在 Cloud TPU v6e 上使用 PyTorch/XLA 訓練 Llama 模型。
存取 Hugging Face 和 Llama 3 模型
您需要 Hugging Face 使用者存取權杖,才能執行本教學課程。如要瞭解如何建立使用者存取權杖,請參閱 Hugging Face 使用者存取權杖說明文件。
此外,您也需要取得在 Hugging Face 存取 Llama-3-8B 模型的權限。如要取得存取權,請前往 HuggingFace 上的 Meta-Llama-3-8B 模型,然後要求存取權。
建立 Cloud TPU VM
建立具有 8 個晶片的 Cloud TPU v6e,以執行本教學課程。
設定環境變數:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-8 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration
建立 Cloud TPU VM:
gcloud alpha compute tpus tpu-vm create ${NODE_ID} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
安裝
安裝 Hugging Face Transformers 的 pytorch-tpu/transformers
fork 和依附元件。本教學課程已使用下列範例中的依附元件版本進行測試:
torch
:相容於 2.5.0torch_xla[tpu]
:相容於 2.5.0jax
:0.4.33jaxlib
:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
設定模型設定
下一節「執行模型」中的訓練指令會使用兩個 JSON 設定檔,定義模型參數和 Fully Sharded Data Parallel (FSDP) 設定。透過 FSDP 分片,您可以在訓練時使用較大的批次大小,方法是將模型權重分散到多個 TPU。使用較小的模型進行訓練時,可能只要使用資料平行處理,並在每部裝置上複製權重即可。如要進一步瞭解如何在 PyTorch/XLA 中跨裝置分片張量,請參閱 PyTorch/XLA SPMD 使用指南。
建立模型參數設定檔。以下是 Llama-3-8B 的模型參數設定。如要使用其他模型,請在 Hugging Face 上尋找設定。 例如,請參閱 Llama-2-7B 設定。
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
建立 FSDP 設定檔:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
如要進一步瞭解 FSDP,請參閱 FSDPv2。
使用下列指令,將設定檔上傳至 Cloud TPU VM:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${NODE_ID}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
執行模型
使用您在前一節建立的設定檔,執行 run_clm.py
指令碼,在 WikiText 資料集上訓練 Llama-3-8B 模型。訓練指令碼在 Cloud TPU v6e-8 上執行約需 10 分鐘。
在 Cloud TPU 上使用下列指令登入 Hugging Face:
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
執行模型訓練:
gcloud alpha compute tpus tpu-vm ssh ${NODE_ID} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
排解 PyTorch/XLA 問題
如果您在上一節中設定了用於偵錯的選用變數,模型的設定檔會儲存在變數 PROFILE_LOGDIR
指定的位置。您可以擷取儲存在這個位置的 xplane.pb
檔案,並使用 tensorboard
按照 TensorBoard 指示,在瀏覽器中查看設定檔。
如果 PyTorch/XLA 未如預期運作,請參閱疑難排解指南,瞭解如何偵錯、分析及最佳化模型。
基準化結果
以下章節包含 MaxDiffusion 在 v6e 上的基準測試結果。
MaxDiffusion
我們在 v6e-4、v6e-16 和兩個 v6e-16 上執行 MaxDiffusion 的訓練指令碼。請參閱下表中的輸送量。
v6e-4 | v6e-16 | 兩個 v6e-16 | |
---|---|---|---|
訓練步驟 | 0.069 | 0.073 | 0.13 |
全域批次大小 | 8 | 32 | 64 |
處理量 (每秒樣本數) | 115.9 | 438.4 | 492.3 |