使用 TPU v5e 訓練模型

每個 Pod 的晶片數量較少 (256 個),因此 TPU v5e 經過最佳化調整,可成為 Transformer、文字轉圖片和卷積類神經網路 (CNN) 訓練、微調和服務的高價值產品。如要進一步瞭解如何使用 Cloud TPU v5e 進行服務,請參閱「使用 v5e 進行推論」。

如要進一步瞭解 Cloud TPU v5e TPU 硬體和設定,請參閱「TPU v5e」。

開始使用

以下各節說明如何開始使用 TPU v5e。

要求配額

如要使用 TPU v5e 進行訓練,您需要配額。隨選 TPU、預留 TPU 和 TPU Spot VM 的配額類型不同。如果您使用 TPU v5e 進行推論,則需要不同的配額。如要進一步瞭解配額,請參閱配額。如要申請 TPU v5e 配額,請與 Cloud 銷售團隊聯絡。

建立 Google Cloud 帳戶和專案

如要使用 Cloud TPU,您需要 Google Cloud 帳戶和專案。詳情請參閱「設定 Cloud TPU 環境」。

建立 Cloud TPU

最佳做法是使用 queued-resource create 指令,將 Cloud TPU v5e 佈建為佇列資源。詳情請參閱「管理已加入佇列的資源」。

您也可以使用 Create Node API (gcloud compute tpus tpu-vm create) 佈建 Cloud TPU v5e。詳情請參閱「管理 TPU 資源」。

如要進一步瞭解可用的 v5e 訓練設定,請參閱「Cloud TPU v5e 訓練類型」。

設定架構

本節說明使用 JAX 或 PyTorch 和 TPU v5e 訓練自訂模型的一般設定程序。

如需推論設定操作說明,請參閱 v5e 推論簡介

定義一些環境變數:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

設定 JAX

如果切片形狀大於 8 個晶片,則一個切片中會有數個 VM。在這種情況下,您需要使用 --worker=all 旗標,在單一步驟中於所有 TPU VM 上執行安裝作業,不必使用 SSH 分別登入各個 VM:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

指令旗標說明

變數 說明
TPU_NAME TPU 的使用者指派文字 ID,會在排入佇列的資源要求獲得分配時建立。
PROJECT_ID Google Cloud 專案名稱。使用現有專案或建立新專案 前往 設定專案 Google Cloud
ZONE 如要瞭解支援的區域,請參閱 TPU 區域和區域文件。
工作人員 可存取基礎 TPU 的 TPU VM。

您可以執行下列指令,檢查裝置數量 (這裡顯示的輸出內容是使用 v5litepod-16 切片產生)。這段程式碼會檢查 JAX 是否能看到 Cloud TPU TensorCore,並執行基本作業,藉此測試所有項目是否已正確安裝:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

畫面中會顯示如下的輸出結果:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() 會顯示指定切片中的晶片總數。 jax.local_device_count() 表示這個區塊中單一 VM 可存取的晶片數量。

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

畫面中會顯示如下的輸出結果:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

如要開始使用 JAX 訓練 v5e,請參閱本文中的 JAX 教學課程

PyTorch 設定

請注意,v5e 僅支援 PJRT 執行階段,且 PyTorch 2.1 以上版本會將 PJRT 設為所有 TPU 版本的預設執行階段。

本節說明如何開始在 v5e 上使用 PJRT,搭配 PyTorch/XLA 和所有工作站的指令。

安裝依附元件

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

PYTORCH_VERSION 替換為您要使用的 PyTorch 版本。PYTORCH_VERSION 用於指定 PyTorch/XLA 的相同版本。建議使用 2.6.0 版。

如要進一步瞭解 PyTorch 和 PyTorch/XLA 版本,請參閱「PyTorch - Get Started」和「PyTorch/XLA releases」。

如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA 安裝說明

如果安裝 torchtorch_xlatorchvision 的 Wheel 時發生錯誤 (例如 pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222),請使用下列指令降級版本:

pip3 install setuptools==62.1.0

使用 PJRT 執行指令碼

unset LD_PRELOAD

以下範例使用 Python 指令碼,在 v5e VM 上執行計算:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

這麼做會產生類似以下內容的輸出結果:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')

請參閱本文中的 PyTorch 教學課程,瞭解如何使用 PyTorch 開始訓練 v5e 模型。

在工作階段結束時,刪除 TPU 和排入佇列的資源。如要刪除排入佇列的資源,請分 2 個步驟刪除切片和排入佇列的資源:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

這兩個步驟也可用於移除處於 FAILED 狀態的已排隊資源要求。

JAX/FLAX 範例

以下各節提供範例,說明如何在 TPU v5e 上訓練 JAX 和 FLAX 模型。

在 v5e 上訓練 ImageNet

本教學課程說明如何使用偽輸入資料,在 v5e 上訓練 ImageNet。如要使用真實資料,請參閱 GitHub 上的 README 檔案

設定

  1. 建立環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本
    SERVICE_ACCOUNT 服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。

    例如: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 佇列資源要求的使用者指派文字 ID。

  2. 建立 TPU 資源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    排入佇列的資源處於 ACTIVE 狀態時,您就能透過 SSH 連線至 TPU VM:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    當 QueuedResource 處於 ACTIVE 狀態時,輸出內容會類似於下列內容:

     state: ACTIVE
    
  3. 安裝最新版 JAX 和 jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 複製 ImageNet 模型並安裝對應需求:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
    
  5. 如要產生虛假資料,模型需要資料集維度的相關資訊。這項資訊可從 ImageNet 資料集的中繼資料收集而來:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
    

訓練模型

完成所有前述步驟後,即可訓練模型。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

刪除 TPU 和排入佇列的資源

在工作階段結束時刪除 TPU 和排入佇列的資源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hugging Face FLAX 模型

以 FLAX 實作的 Hugging Face 模型可在 Cloud TPU v5e 上直接運作。本節提供執行熱門模型的操作說明。

在 Imagenette 上訓練 ViT

本教學課程說明如何使用 Fast AI Imagenette 資料集,在 Cloud TPU v5e 上訓練 HuggingFace 的 Vision Transformer (ViT) 模型。

ViT 模型是第一個成功在 ImageNet 上訓練 Transformer 編碼器,且與卷積網路相比,結果相當出色的模型。詳情請參閱下列資源:

設定

  1. 建立環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本
    SERVICE_ACCOUNT 服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。

    例如: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 佇列資源要求的使用者指派文字 ID。

  2. 建立 TPU 資源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    排入佇列的資源處於 ACTIVE 狀態時,您就能透過 SSH 連線至 TPU VM:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    當佇列資源處於 ACTIVE 狀態時,輸出內容會類似下列範例:

     state: ACTIVE
    
  3. 安裝 JAX 及其程式庫:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. 下載 Hugging Face 存放區 並安裝必要條件:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
    
  5. 下載 Imagenette 資料集:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

訓練模型

使用預先對應的 4 GB 緩衝區訓練模型。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

刪除 TPU 和排入佇列的資源

在工作階段結束時刪除 TPU 和排入佇列的資源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

ViT 基準測試結果

我們是在 v5litepod-4、v5litepod-16 和 v5litepod-64 上執行訓練指令碼。下表顯示不同加速器類型的輸送量。

加速器類型 v5litepod-4 v5litepod-16 v5litepod-64
訓練週期 3 3 3
全域批次大小 32 128 512
處理量 (每秒樣本數) 263.40 429.34 470.71

訓練寶可夢的 Diffusion 模型

本教學課程說明如何使用 Cloud TPU v5e 上的 Pokémon 資料集,訓練 HuggingFace 的 Stable Diffusion 模型。

Stable Diffusion 模型是潛在的文字轉圖像模型,可根據任何文字輸入生成逼真的圖像。詳情請參閱下列資源:

設定

  1. 為儲存空間 bucket 名稱設定環境變數:

    export GCS_BUCKET_NAME=your_bucket_name
  2. 設定模型輸出內容的儲存空間 bucket:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. 建立環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本
    SERVICE_ACCOUNT 服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。

    例如: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 佇列資源要求的使用者指派文字 ID。

  4. 建立 TPU 資源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    排入佇列的資源處於 ACTIVE 狀態時,您就能透過 SSH 連線至 TPU VM:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    當佇列資源處於 ACTIVE 狀態時,輸出內容會類似下列範例:

     state: ACTIVE
    
  5. 安裝 JAX 和其程式庫。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  6. 下載 HuggingFace 存放區,並安裝必要條件。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
    

訓練模型

使用預先對應的 4 GB 緩衝區訓練模型。

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

清除所用資源

在工作階段結束時,刪除 TPU、排入佇列的資源和 Cloud Storage 值區。

  1. 刪除 TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  2. 刪除排入佇列的資源:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
    
  3. 刪除 Cloud Storage bucket:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}
    

擴散基準化結果

訓練指令碼是在 v5litepod-4、v5litepod-16 和 v5litepod-64 上執行。下表顯示處理量。

加速器類型 v5litepod-4 v5litepod-16 v5litepod-64
訓練步數 1500 1500 1500
全域批次大小 32 64 128
處理量 (每秒樣本數) 36.53 43.71 49.36

PyTorch/XLA

以下各節提供範例,說明如何在 TPU v5e 上訓練 PyTorch/XLA 模型。

使用 PJRT 執行階段訓練 ResNet

PyTorch/XLA 會從 PyTorch 2.0 以上版本的 XRT 遷移至 PjRt。以下是更新版的操作說明,可協助您為 PyTorch/XLA 訓練工作負載設定 v5e。

設定
  1. 建立環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本
    SERVICE_ACCOUNT 服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。

    例如: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 佇列資源要求的使用者指派文字 ID。

  2. 建立 TPU 資源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    QueuedResource 處於 ACTIVE 狀態時,您就能透過 SSH 連線至 TPU VM:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    當佇列資源處於 ACTIVE 狀態時,輸出內容會類似下列範例:

     state: ACTIVE
    
  3. 安裝 Torch/XLA 專屬依附元件

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project=${PROJECT_ID} \
      --zone=${ZONE} \
      --worker=all \
      --command='
         sudo apt-get update -y
         sudo apt-get install libomp5 -y
         pip3 install mkl mkl-include
         pip3 install tf-nightly tb-nightly tbp-nightly
         pip3 install numpy
         sudo apt-get install libopenblas-dev -y
         pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    PYTORCH_VERSION 替換為您要使用的 PyTorch 版本。PYTORCH_VERSION 用於指定 PyTorch/XLA 的相同版本。建議使用 2.6.0 版。

    如要進一步瞭解 PyTorch 和 PyTorch/XLA 版本,請參閱「PyTorch - Get Started」和「PyTorch/XLA releases」。

    如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA 安裝說明

訓練 ResNet 模型
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50  --fake_data --num_epochs=1 —num_workers=16  --log_steps=300 --batch_size=64 --profile'

刪除 TPU 和排入佇列的資源

在工作階段結束時刪除 TPU 和排入佇列的資源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
基準結果

下表顯示基準處理量。

加速器類型 處理量 (每秒範例數)
v5litepod-4 4240 ex/s
v5litepod-16 10,810 ex/s
v5litepod-64 46,154 ex/s

在 v5e 上訓練 ViT

本教學課程將說明如何使用 HuggingFace 存放區,在 PyTorch/XLA 上對 cifar10 資料集執行 VIT v5e。

設定

  1. 建立環境變數:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    環境變數說明

    變數 說明
    PROJECT_ID 您的 Google Cloud 專案 ID。使用現有專案或建立新專案
    TPU_NAME TPU 的名稱。
    ZONE 要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域
    ACCELERATOR_TYPE 加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。
    RUNTIME_VERSION Cloud TPU 軟體版本
    SERVICE_ACCOUNT 服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。

    例如: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com

    QUEUED_RESOURCE_ID 佇列資源要求的使用者指派文字 ID。

  2. 建立 TPU 資源:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --service-account=${SERVICE_ACCOUNT}
    

    QueuedResource 處於 ACTIVE 狀態時,您就能透過 SSH 連線至 TPU VM:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    當佇列資源處於 ACTIVE 狀態時,輸出內容會類似如下:

     state: ACTIVE
    
  3. 安裝 PyTorch/XLA 依附元件

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

    PYTORCH_VERSION 替換為您要使用的 PyTorch 版本。PYTORCH_VERSION 用於指定 PyTorch/XLA 的相同版本。建議使用 2.6.0 版。

    如要進一步瞭解 PyTorch 和 PyTorch/XLA 版本,請參閱「PyTorch - Get Started」和「PyTorch/XLA releases」。

    如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA 安裝說明

  4. 下載 HuggingFace 存放區,並安裝必要條件。

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
          wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
    

訓練模型

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

刪除 TPU 和排入佇列的資源

在工作階段結束時刪除 TPU 和排入佇列的資源。

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

基準結果

下表列出不同加速器類型的基準輸送量。

v5litepod-4 v5litepod-16 v5litepod-64
訓練週期 3 3 3
全域批次大小 32 128 512
處理量 (每秒樣本數) 201 657 2,844