使用 TPU v6e 訓練模型

本文將逐步說明如何在 Cloud TPU v6e (又稱 Trillium) 上訓練模型,涵蓋環境設定、效能最佳化,以及使用 JAX 和 PyTorch/XLA 的實用訓練範例。

TPU v6e (又稱 Trillium) 是 Google 的第 6 代 TPU,在所有技術介面 (例如 API 和記錄) 和本文件中,Trillium 都會稱為 v6e。每個 Pod 都有 256 個晶片,因此 TPU v6e 的架構與 v5e 有許多相似之處。TPU v6e 經過最佳化,適用於 Transformer、文字轉圖片和卷積類神經網路 (CNN) 的訓練、微調和服務。如要進一步瞭解 TPU v6e 系統架構和設定,請參閱 TPU v6e

如要瞭解如何在 Cloud TPU v6e 上執行推論作業,請參閱下列教學課程:

事前準備

開始前,請先完成下列事項:

  • 建立 Google Cloud 帳戶和專案並啟用計費功能
  • 安裝 Google Cloud CLI Alpha 版元件
  • 啟用 Cloud TPU API
  • 建立 Cloud TPU 服務代理
  • 建立 Cloud TPU 服務帳戶並授予權限

詳情請參閱「設定 Cloud TPU 環境」。

確認配額和權限

確認專案具有下列配額:

如果您將 GKE 與 XPK 搭配使用,則需要在 Google Cloud 控制台中取得額外權限。詳情請參閱「Google Cloud 控制台 需要的權限」。

佈建 TPU

您可以使用下列方法佈建及管理 TPU v6e:

  • GKE:您可以透過 GKE 佈建及管理 TPU,做為容器化機器學習工作負載的加速器集區。詳情請參閱「關於 GKE 中的 TPU」。
  • GKE 和 XPK:XPK 是一項指令列工具,可簡化在 GKE 上建立叢集和執行工作負載的作業。這項服務專為 ML 從業人員設計,可讓他們佈建 TPU 並執行訓練作業,不必具備深入的 Kubernetes 專業知識。詳情請參閱 XPK GitHub 存放區
  • Cloud TPU 佇列資源:您可以要求佇列資源,在 TPU 容量可用時進行佈建。非常適合可等待佇列的批次工作和容錯工作負載。你可以為要求指定時間範圍。詳情請參閱「管理已加入佇列的資源」。

透過 GKE 和 XPK 佈建 v6e Cloud TPU

如果您使用 v6e 搭配 GKE 指令,可以透過 Kubernetes 指令或 XPK 佈建 Cloud TPU,並訓練或提供模型。請參閱「規劃 GKE 中的 Cloud TPU」,瞭解如何在 GKE 叢集中規劃 Cloud TPU 設定。下列各節提供指令,可建立支援單一 NIC 和多個 NIC 的 XPK 叢集。

建立支援單一 NIC 的 XPK 叢集

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

指令旗標說明

變數 說明
CLUSTER_NAME 使用者指派的 XPK 叢集名稱。
PROJECT_ID Google Cloud 專案名稱。使用現有專案或建立新專案。 詳情請參閱「設定 Google Cloud 專案」。
ZONE 如要瞭解支援的區域,請參閱「Cloud TPU 地區和區域」文件。
TPU_TYPE 請參閱「加速器類型」。
NUM_SLICES 要建立的切片數量
CLUSTER_ARGUMENTS 要使用的網路和子網路。

例如:--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES 要建立的切片數量。
NETWORK_NAME 要使用的次要網路名稱。
NETWORK_FW_NAME 要使用的次要網路防火牆名稱。

建立支援多個 NIC 的 XPK 叢集

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

指令旗標說明

變數 說明
CLUSTER_NAME 使用者指派的 XPK 叢集名稱。
PROJECT_ID Google Cloud 專案名稱。使用現有專案或建立新專案。 詳情請參閱「設定 Google Cloud 專案」。
ZONE 如要瞭解支援的區域,請參閱「Cloud TPU 地區和區域」文件。
TPU_TYPE 請參閱「加速器類型」。
NUM_SLICES 要建立的切片數量
CLUSTER_ARGUMENTS 要使用的網路和子網路。

例如:--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS 要使用的額外節點網路。

例如:--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES 要建立的切片數量 (僅限多切片)。
NETWORK_NAME 要使用的次要網路名稱。
NETWORK_FW_NAME 要使用的次要網路防火牆名稱。

設定 JAX 或 PyTorch

下列資源說明如何在 Cloud TPU 上設定 JAX 或 PyTorch,具體做法取決於您使用的佈建和管理方法:

如要使用 MaxText 設定及執行 XPK,請參閱「使用 XPK 大規模執行 MaxText 」。

提升網路效能

本節說明如何設定最大傳輸單元 (MTU)、在多重切片環境中使用多個 NIC,以及改善 TCP 設定,藉此提升網路效能。

設定 MTU

如要獲得最佳網路效能,請使用 MTU (最大傳輸單位) 為 8,896 的網路。

根據預設,虛擬私有雲 (VPC) 僅提供 1,460 位元組的 MTU,這會導致網路效能不佳。您可以將虛擬私有雲網路的 MTU 設為 1,300 到 8,896 位元組 (含) 之間的任何值。常見的自訂 MTU 大小為 1,500 個位元組 (標準乙太網路) 或 8,896 個位元組 (最大可能值)。詳情請參閱「有效的虛擬私有雲網路 MTU 大小」。

如要進一步瞭解如何變更現有或預設網路的 MTU 設定,請參閱「變更虛擬私有雲網路的 MTU 設定」。

下列範例會建立 MTU 為 8,896 的網路,以及允許網路內 TCP、ICMP 和 UDP 流量的相應防火牆規則。

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

your-resource-name 替換為網路和防火牆的基本名稱。

使用多 NIC 選項進行多配量

如果您使用 Multislice 環境,請設定下列環境變數,這些變數是次要子網路的必要條件:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

使用下列指令,為網路和子網路建立自訂 IP 路由。

  1. 建立次要網路。

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. 為次要網路建立子網路。

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. 建立防火牆規則,允許新子網路內的流量。

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. 為次要網路建立 Cloud Router。

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. 為 Cloud Router 建立 NAT 設定。

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

建立多重網路切片後,您可以設定 XPK 叢集,並在 XPK 工作負載建立指令中新增 --command ifconfig 旗標,驗證是否同時使用兩個網路介面卡 (NIC)。

  1. 使用下列 workload create 指令,在 Google Cloud 控制台記錄中顯示 ifconfig 指令的輸出內容,並確認 eth0 和 eth1 的 MTU 都設為 8,896。

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    如要啟用偵錯記錄或使用 Vertex AI TensorBoard,請在指令中加入下列選用引數:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. 在 Google Cloud 控制台記錄中檢查 XPK 工作負載的輸出內容,確認 eth0 和 eth1 的 MTU 都設為 8,896。

改善 TCP 設定

如果您使用排入佇列的資源佈建 Cloud TPU,可以執行下列指令,提高 TCP 接收緩衝區限制,藉此提升網路效能。

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

最佳化記憶體分配效能

在 Cloud TPU VM 上,系統預設會使用 tcmalloc 程式庫,改善模型效能,並頻繁分配大量記憶體。這項設定是透過 LD_PRELOAD 環境變數進行。

不過,對於某些工作負載 (例如具有非常大的嵌入資料表分配的 DLRM),tcmalloc 可能會導致速度變慢。在這種情況下,您可以在執行訓練指令碼前,取消設定殼層工作階段中的 LD_PRELOAD 變數,還原為標準 malloc 函式:

unset LD_PRELOAD

使用 SkyPilot

您可以搭配 SkyPilot 使用 Cloud TPU v6e。SkyPilot 是開放原始碼架構,可簡化執行、管理及調度 AI 工作負載的程序。你可以在 SkyPilot 中新增 v6e 相關位置和定價資訊。 詳情請參閱 SkyPilot TPU v6e 範例

訓練範例

下列各節提供在 Cloud TPU v6e 上訓練 MaxText、MaxDiffusion 和 PyTorch 模型的範例。

這些範例已使用下列軟體版本測試:

  • Python 3.10 以上版本
  • 夜間軟體版本:
    • 每晚 JAX 0.4.32.dev20240912
    • 每晚 LibTPU 0.1.dev20240912+nightly
  • 穩定版軟體:
    • JAX + JAX Lib of v0.4.37

在 Cloud TPU v6e 上訓練 MaxText 和 MaxDiffusion

以下各節將說明 MaxTextMaxDiffusion 模型的訓練生命週期。

一般來說,高階步驟如下:

  1. 建構工作負載基本映像檔。
  2. 使用 XPK 執行工作負載。
    1. 為工作負載建構訓練指令。
    2. 部署工作負載。
  3. 追蹤工作負載並查看指標。
  4. 如不需要,請刪除 XPK 工作負載。
  5. 不再需要 XPK 叢集時,請將其刪除。

建構基本映像檔

安裝 MaxText 或 MaxDiffusion,並建構 Docker 映像檔:

  1. 複製要使用的存放區,然後變更為存放區的目錄:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. 將 Docker 設為使用 Google Cloud CLI:

    gcloud auth configure-docker
    
  3. 使用下列指令或 JAX AI 圖片建構 Docker 映像檔。 如要進一步瞭解 JAX AI 圖片,請參閱「JAX AI 圖片」。

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. 在有效的 gcloud CLI 設定中設定專案 ID:

    gcloud config set project ${PROJECT_ID}
    
  5. 如果從沒有在本機建構映像檔的機器啟動工作負載,請上傳映像檔。

    1. 設定 CLOUD_IMAGE_NAME 環境變數:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. 上傳圖片:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

使用 XPK 執行工作負載

  1. 如果未使用 MaxText 設定的預設值MaxDiffusion,請設定下列環境變數:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. 建構模型指令碼。這個指令碼會在後續步驟中複製為訓練指令。

    請先不要執行模型指令碼。

    MaxText

    MaxText 是以純 Python 和 JAX 編寫的開放原始碼 LLM,具備高效能和高擴充性,適用於 TPU 和 GPU,可進行訓練和推論。 Google Cloud

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma 是 Google DeepMind 開發的一系列開放權重 LLM,以 Gemini 研究和技術為基礎。

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral 是由 Mistral AI 開發的頂尖 AI 模型,採用稀疏混合專家 (MoE) 架構。

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama 是 Meta 開發的一系列開放權重 LLM。

    如需瞭解如何在 PyTorch 上執行 Llama3,請參閱 torchprime 存放區中的 torch_xla 模型

    MaxDiffusion

    MaxDiffusion 是一系列以純 Python 和 JAX 編寫的各種延遲擴散模型參考實作,可在 XLA 裝置上執行,包括 Cloud TPU 和 GPU。Stable Diffusion 是潛在文字轉圖像模型,可根據任何文字輸入生成逼真的圖像。

    您需要安裝特定 Git 分支,才能執行 MaxDiffusion,如下列訓練指令碼所示。

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. 匯出下列變數:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    環境變數說明

    變數 說明
    CLUSTER_NAME XPK 叢集的名稱。
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱 TPU 版本
    NUM_SLICES TPU 配量數量。
    YOUR_MODEL_SCRIPT 要以訓練指令執行的模型指令碼。
  4. 使用上一步建立的指令碼執行模型。您必須指定 --base-docker-image 旗標來使用 MaxText 基礎映像檔,或是指定 --docker-image 旗標和要使用的映像檔。

    您可以選擇新增下列選用旗標:

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    輸出內容包含追蹤工作負載的連結。 開啟連結並點選「記錄」分頁,即可即時追蹤工作負載。

在 MaxText 上偵錯 JAX

使用補充 XPK 指令,診斷叢集或工作負載無法執行的原因:

使用 Vertex AI 監控 MaxText 上的 JAX

如要使用 TensorBoard,您的 Google Cloud 使用者帳戶必須具備aiplatform.user角色。執行下列指令來授予這個角色:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

透過 Vertex AI 管理的 TensorBoard 查看純量和設定檔資料。

  1. 將您使用的區域資源管理 (CRUD) 要求數從 600 提高至 5000。如果工作負載較小,使用的 VM 少於 16 個,可能不會有問題。

  2. 安裝 Vertex AI 的依附元件,例如 cloud-accelerator-diagnostics

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. 使用 --create-vertex-tensorboard 旗標建立 XPK 叢集,如「建立 Vertex AI TensorBoard」一文所述。您也可以在現有叢集上執行這項指令。

  4. 使用 --use-vertex-tensorboard 旗標和選用的 --experiment-name 旗標執行 XPK 工作負載時,請建立 Vertex AI 實驗。如需完整步驟清單,請參閱「建立 Vertex AI 實驗,將資料上傳至 Vertex AI TensorBoard」。

記錄會包含 Vertex AI TensorBoard 的連結,類似於下列連結:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

您也可以在 Google Cloud 控制台中找到 Vertex AI TensorBoard 連結。 前往 Google Cloud 控制台中的 Vertex AI Experiments。從下拉式選單中選取適當的地區。

TensorBoard 目錄也會寫入您以 ${BASE_OUTPUT_DIR} 指定的 Cloud Storage bucket。

刪除 XPK 工作負載

使用 xpk workload delete 指令,根據工作前置字元或工作狀態刪除一或多個工作負載。如果您傳送的 XPK 工作負載不再需要執行,或是工作停滯在佇列中,這個指令就非常實用。

刪除 XPK 叢集

使用 xpk cluster delete 指令刪除叢集:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

MaxDiffusion 基準化結果

我們在 v6e-4、v6e-16 和兩個 v6e-16 上執行 MaxDiffusion 的訓練指令碼。下表顯示測得的處理量。

v6e-4 v6e-16 兩個 v6e-16
訓練步驟 0.069 0.073 0.13
全域批次大小 8 32 64
處理量 (每秒樣本數) 115.9 438.4 492.3

在 Cloud TPU v6e 上使用 PyTorch/XLA 訓練 Llama 模型

本節說明如何使用 PyTorch/XLA 在 Cloud TPU v6e 上,透過 WikiText 資料集訓練 Llama 模型。

存取 Hugging Face 和 Llama 3 模型

您需要 Hugging Face 使用者存取權杖才能執行這個範例。如要瞭解如何建立使用者存取權杖,請參閱 Hugging Face 的使用者存取權杖說明文件

此外,您也需要取得在 Hugging Face 存取 Llama-3-8B 模型的權限。如要取得存取權,請前往 HuggingFace 上的 Meta-Llama-3-8B 模型,然後要求存取權。

建立 Cloud TPU VM

在本範例中,請建立具有 8 個晶片的 Cloud TPU v6e。

  1. 設定環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本

  2. 建立 Cloud TPU VM:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

安裝

安裝 Hugging Face Transformers 的 pytorch-tpu/transformers fork 和依附元件。這個範例已使用下列依附元件版本進行測試:

  • torch:與 2.5.0 相容
  • torch_xla[tpu]:與 2.5.0 相容
  • jax:0.4.33
  • jaxlib:0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

設定模型設定檔

下一節的訓練指令「執行模型」會使用兩個 JSON 設定檔,定義模型參數和完全分片資料平行 (FSDP) 設定。透過 FSDP 分片,您可以在訓練時使用較大的批次大小,方法是將模型權重分散到多個 TPU。使用較小的模型進行訓練時,可能只要使用資料平行處理,並在每部裝置上複製權重即可。如要進一步瞭解如何在 PyTorch/XLA 中跨裝置分片張量,請參閱 PyTorch/XLA SPMD 使用指南

  1. 建立模型參數設定檔。以下是 Llama-3-8B 的模型參數設定。如為其他模型,請在 Hugging Face 上尋找設定檔。舉例來說,請參閱 Llama-2-7B 設定

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. 建立 FSDP 設定檔:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    如要進一步瞭解 FSDP,請參閱「Fully Sharded Data Parallel using SPMD 」。

  3. 使用下列指令,將設定檔上傳至 Cloud TPU VM:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

執行模型

使用您在前一節中建立的設定檔,執行 run_clm.py 指令碼,在 WikiText 資料集上訓練 Llama-3-8B 模型。訓練指令碼在 Cloud TPU v6e-8 上執行約需 10 分鐘。

  1. 在 Cloud TPU 上使用下列指令登入 Hugging Face:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. 執行模型訓練:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

排解 PyTorch/XLA 問題

如果您在上一節中設定了用於偵錯的選用變數,模型的設定檔會儲存在變數 PROFILE_LOGDIR 指定的位置。您可以擷取儲存在這個位置的 xplane.pb 檔案,並使用 tensorboard 依據 TensorBoard 指示,在瀏覽器中查看設定檔。

如果 PyTorch/XLA 未如預期運作,請參閱疑難排解指南,瞭解如何偵錯、分析及最佳化模型。