使用 TPU v5e 訓練模型
每個 Pod 的晶片數量較少 (256 個),因此 TPU v5e 經過最佳化調整,可成為 Transformer、文字轉圖片和卷積類神經網路 (CNN) 訓練、微調和服務的高價值產品。如要進一步瞭解如何使用 Cloud TPU v5e 進行服務,請參閱「使用 v5e 進行推論」。
如要進一步瞭解 Cloud TPU v5e TPU 硬體和設定,請參閱「TPU v5e」。
開始使用
以下各節說明如何開始使用 TPU v5e。
要求配額
如要使用 TPU v5e 進行訓練,您需要配額。隨選 TPU、預留 TPU 和 TPU Spot VM 的配額類型不同。如果您使用 TPU v5e 進行推論,則需要不同的配額。如要進一步瞭解配額,請參閱配額。如要申請 TPU v5e 配額,請與 Cloud 銷售團隊聯絡。
建立 Google Cloud 帳戶和專案
如要使用 Cloud TPU,您需要 Google Cloud 帳戶和專案。詳情請參閱「設定 Cloud TPU 環境」。
建立 Cloud TPU
最佳做法是使用 queued-resource create
指令,將 Cloud TPU v5e 佈建為佇列資源。詳情請參閱「管理已加入佇列的資源」。
您也可以使用 Create Node API (gcloud compute tpus tpu-vm create
) 佈建 Cloud TPU v5e。詳情請參閱「管理 TPU 資源」。
如要進一步瞭解可用的 v5e 訓練設定,請參閱「Cloud TPU v5e 訓練類型」。
設定架構
本節說明使用 JAX 或 PyTorch 和 TPU v5e 訓練自訂模型的一般設定程序。
如需推論設定操作說明,請參閱 v5e 推論簡介。
定義一些環境變數:
export PROJECT_ID=your_project_ID export ACCELERATOR_TYPE=v5litepod-16 export ZONE=us-west4-a export TPU_NAME=your_tpu_name export QUEUED_RESOURCE_ID=your_queued_resource_id
設定 JAX
如果切片形狀大於 8 個晶片,則一個切片中會有數個 VM。在這種情況下,您需要使用 --worker=all
旗標,在單一步驟中於所有 TPU VM 上執行安裝作業,不必使用 SSH 分別登入各個 VM:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
指令旗標說明
變數 | 說明 |
TPU_NAME | TPU 的使用者指派文字 ID,會在排入佇列的資源要求獲得分配時建立。 |
PROJECT_ID | Google Cloud 專案名稱。使用現有專案或建立新專案 前往 設定專案 Google Cloud |
ZONE | 如要瞭解支援的區域,請參閱 TPU 區域和區域文件。 |
工作人員 | 可存取基礎 TPU 的 TPU VM。 |
您可以執行下列指令,檢查裝置數量 (這裡顯示的輸出內容是使用 v5litepod-16 切片產生)。這段程式碼會檢查 JAX 是否能看到 Cloud TPU TensorCore,並執行基本作業,藉此測試所有項目是否已正確安裝:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
畫面中會顯示如下的輸出結果:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count()
會顯示指定切片中的晶片總數。
jax.local_device_count()
表示這個區塊中單一 VM 可存取的晶片數量。
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the
# jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
畫面中會顯示如下的輸出結果:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
如要開始使用 JAX 訓練 v5e,請參閱本文中的 JAX 教學課程。
PyTorch 設定
請注意,v5e 僅支援 PJRT 執行階段,且 PyTorch 2.1 以上版本會將 PJRT 設為所有 TPU 版本的預設執行階段。
本節說明如何開始在 v5e 上使用 PJRT,搭配 PyTorch/XLA 和所有工作站的指令。
安裝依附元件
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip install mkl mkl-include pip install tf-nightly tb-nightly tbp-nightly pip install numpy sudo apt-get install libopenblas-dev -y pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
將 PYTORCH_VERSION
替換為您要使用的 PyTorch 版本。PYTORCH_VERSION
用於指定 PyTorch/XLA 的相同版本。建議使用 2.6.0 版。
如要進一步瞭解 PyTorch 和 PyTorch/XLA 版本,請參閱「PyTorch - Get Started」和「PyTorch/XLA releases」。
如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA 安裝說明。
如果安裝 torch
、torch_xla
或 torchvision
的 Wheel 時發生錯誤 (例如 pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end
or semicolon (after name and no valid version specifier) torch==nightly+20230222
),請使用下列指令降級版本:
pip3 install setuptools==62.1.0
使用 PJRT 執行指令碼
unset LD_PRELOAD
以下範例使用 Python 指令碼,在 v5e VM 上執行計算:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
這麼做會產生類似以下內容的輸出結果:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
請參閱本文中的 PyTorch 教學課程,瞭解如何使用 PyTorch 開始訓練 v5e 模型。
在工作階段結束時,刪除 TPU 和排入佇列的資源。如要刪除排入佇列的資源,請分 2 個步驟刪除切片和排入佇列的資源:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
這兩個步驟也可用於移除處於 FAILED
狀態的已排隊資源要求。
JAX/FLAX 範例
以下各節提供範例,說明如何在 TPU v5e 上訓練 JAX 和 FLAX 模型。
在 v5e 上訓練 ImageNet
本教學課程說明如何使用偽輸入資料,在 v5e 上訓練 ImageNet。如要使用真實資料,請參閱 GitHub 上的 README 檔案。
設定
建立環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-8 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
環境變數說明
變數 說明 PROJECT_ID
您的 Google Cloud 專案 ID。使用現有專案或建立新專案。 TPU_NAME
TPU 的名稱。 ZONE
要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域。 ACCELERATOR_TYPE
加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。 RUNTIME_VERSION
Cloud TPU 軟體版本。 SERVICE_ACCOUNT
服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。 例如:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
佇列資源要求的使用者指派文字 ID。 -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
排入佇列的資源處於
ACTIVE
狀態時,您就能透過 SSH 連線至 TPU VM:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
當 QueuedResource 處於
ACTIVE
狀態時,輸出內容會類似於下列內容:state: ACTIVE
安裝最新版 JAX 和 jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
複製 ImageNet 模型並安裝對應需求:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
如要產生虛假資料,模型需要資料集維度的相關資訊。這項資訊可從 ImageNet 資料集的中繼資料收集而來:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"
訓練模型
完成所有前述步驟後,即可訓練模型。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"
刪除 TPU 和排入佇列的資源
在工作階段結束時刪除 TPU 和排入佇列的資源。
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hugging Face FLAX 模型
以 FLAX 實作的 Hugging Face 模型可在 Cloud TPU v5e 上直接運作。本節提供執行熱門模型的操作說明。
在 Imagenette 上訓練 ViT
本教學課程說明如何使用 Fast AI Imagenette 資料集,在 Cloud TPU v5e 上訓練 HuggingFace 的 Vision Transformer (ViT) 模型。
ViT 模型是第一個成功在 ImageNet 上訓練 Transformer 編碼器,且與卷積網路相比,結果相當出色的模型。詳情請參閱下列資源:
設定
建立環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
環境變數說明
變數 說明 PROJECT_ID
您的 Google Cloud 專案 ID。使用現有專案或建立新專案。 TPU_NAME
TPU 的名稱。 ZONE
要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域。 ACCELERATOR_TYPE
加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。 RUNTIME_VERSION
Cloud TPU 軟體版本。 SERVICE_ACCOUNT
服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。 例如:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
佇列資源要求的使用者指派文字 ID。 -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
排入佇列的資源處於
ACTIVE
狀態時,您就能透過 SSH 連線至 TPU VM:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
當佇列資源處於
ACTIVE
狀態時,輸出內容會類似下列範例:state: ACTIVE
安裝 JAX 及其程式庫:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
下載 Hugging Face 存放區 並安裝必要條件:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
下載 Imagenette 資料集:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
訓練模型
使用預先對應的 4 GB 緩衝區訓練模型。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
刪除 TPU 和排入佇列的資源
在工作階段結束時刪除 TPU 和排入佇列的資源。
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
ViT 基準測試結果
我們是在 v5litepod-4、v5litepod-16 和 v5litepod-64 上執行訓練指令碼。下表顯示不同加速器類型的輸送量。
加速器類型 | v5litepod-4 | v5litepod-16 | v5litepod-64 |
訓練週期 | 3 | 3 | 3 |
全域批次大小 | 32 | 128 | 512 |
處理量 (每秒樣本數) | 263.40 | 429.34 | 470.71 |
訓練寶可夢的 Diffusion 模型
本教學課程說明如何使用 Cloud TPU v5e 上的 Pokémon 資料集,訓練 HuggingFace 的 Stable Diffusion 模型。
Stable Diffusion 模型是潛在的文字轉圖像模型,可根據任何文字輸入生成逼真的圖像。詳情請參閱下列資源:
設定
為儲存空間 bucket 名稱設定環境變數:
export GCS_BUCKET_NAME=your_bucket_name
設定模型輸出內容的儲存空間 bucket:
gcloud storage buckets create gs://GCS_BUCKET_NAME \ --project=your_project \ --location=us-west1
建立環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west1-c export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
環境變數說明
變數 說明 PROJECT_ID
您的 Google Cloud 專案 ID。使用現有專案或建立新專案。 TPU_NAME
TPU 的名稱。 ZONE
要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域。 ACCELERATOR_TYPE
加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。 RUNTIME_VERSION
Cloud TPU 軟體版本。 SERVICE_ACCOUNT
服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。 例如:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
佇列資源要求的使用者指派文字 ID。 -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
排入佇列的資源處於
ACTIVE
狀態時,您就能透過 SSH 連線至 TPU VM:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
當佇列資源處於
ACTIVE
狀態時,輸出內容會類似下列範例:state: ACTIVE
安裝 JAX 和其程式庫。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
下載 HuggingFace 存放區,並安裝必要條件。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'
訓練模型
使用預先對應的 4 GB 緩衝區訓練模型。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID} --worker=all --command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
pip3 install gcsfs
export LIBTPU_INIT_ARGS=''
python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"
清除所用資源
在工作階段結束時,刪除 TPU、排入佇列的資源和 Cloud Storage 值區。
刪除 TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
刪除排入佇列的資源:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --quiet
刪除 Cloud Storage bucket:
gcloud storage rm -r gs://${GCS_BUCKET_NAME}
擴散基準化結果
訓練指令碼是在 v5litepod-4、v5litepod-16 和 v5litepod-64 上執行。下表顯示處理量。
加速器類型 | v5litepod-4 | v5litepod-16 | v5litepod-64 |
訓練步數 | 1500 | 1500 | 1500 |
全域批次大小 | 32 | 64 | 128 |
處理量 (每秒樣本數) | 36.53 | 43.71 | 49.36 |
PyTorch/XLA
以下各節提供範例,說明如何在 TPU v5e 上訓練 PyTorch/XLA 模型。
使用 PJRT 執行階段訓練 ResNet
PyTorch/XLA 會從 PyTorch 2.0 以上版本的 XRT 遷移至 PjRt。以下是更新版的操作說明,可協助您為 PyTorch/XLA 訓練工作負載設定 v5e。
設定
建立環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
環境變數說明
變數 說明 PROJECT_ID
您的 Google Cloud 專案 ID。使用現有專案或建立新專案。 TPU_NAME
TPU 的名稱。 ZONE
要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域。 ACCELERATOR_TYPE
加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。 RUNTIME_VERSION
Cloud TPU 軟體版本。 SERVICE_ACCOUNT
服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。 例如:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
佇列資源要求的使用者指派文字 ID。 -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
QueuedResource 處於
ACTIVE
狀態時,您就能透過 SSH 連線至 TPU VM:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
當佇列資源處於
ACTIVE
狀態時,輸出內容會類似下列範例:state: ACTIVE
安裝 Torch/XLA 專屬依附元件
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'
將
PYTORCH_VERSION
替換為您要使用的 PyTorch 版本。PYTORCH_VERSION
用於指定 PyTorch/XLA 的相同版本。建議使用 2.6.0 版。如要進一步瞭解 PyTorch 和 PyTorch/XLA 版本,請參閱「PyTorch - Get Started」和「PyTorch/XLA releases」。
如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA 安裝說明。
訓練 ResNet 模型
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git checkout release-r2.6
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 —num_workers=16 --log_steps=300 --batch_size=64 --profile'
刪除 TPU 和排入佇列的資源
在工作階段結束時刪除 TPU 和排入佇列的資源。
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
基準結果
下表顯示基準處理量。
加速器類型 | 處理量 (每秒範例數) |
v5litepod-4 | 4240 ex/s |
v5litepod-16 | 10,810 ex/s |
v5litepod-64 | 46,154 ex/s |
在 v5e 上訓練 ViT
本教學課程將說明如何使用 HuggingFace 存放區,在 PyTorch/XLA 上對 cifar10 資料集執行 VIT v5e。
設定
建立環境變數:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-west4-a export ACCELERATOR_TYPE=v5litepod-16 export RUNTIME_VERSION=v2-alpha-tpuv5-lite export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
環境變數說明
變數 說明 PROJECT_ID
您的 Google Cloud 專案 ID。使用現有專案或建立新專案。 TPU_NAME
TPU 的名稱。 ZONE
要建立 TPU VM 的區域。如要進一步瞭解支援的區域,請參閱 TPU 區域和區域。 ACCELERATOR_TYPE
加速器類型會指定您要建立的 Cloud TPU 版本和大小。如要進一步瞭解各個 TPU 版本支援的加速器類型,請參閱「TPU 版本」。 RUNTIME_VERSION
Cloud TPU 軟體版本。 SERVICE_ACCOUNT
服務帳戶的電子郵件地址。如要查看服務帳戶,請前往 Google Cloud 控制台的「Service Accounts」(服務帳戶) 頁面。 例如:
tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
QUEUED_RESOURCE_ID
佇列資源要求的使用者指派文字 ID。 -
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
QueuedResource 處於
ACTIVE
狀態時,您就能透過 SSH 連線至 TPU VM:gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project=${PROJECT_ID} \ --zone=${ZONE}
當佇列資源處於
ACTIVE
狀態時,輸出內容會類似如下:state: ACTIVE
安裝 PyTorch/XLA 依附元件
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=' sudo apt-get update -y sudo apt-get install libomp5 -y pip3 install mkl mkl-include pip3 install tf-nightly tb-nightly tbp-nightly pip3 install numpy sudo apt-get install libopenblas-dev -y pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
將
PYTORCH_VERSION
替換為您要使用的 PyTorch 版本。PYTORCH_VERSION
用於指定 PyTorch/XLA 的相同版本。建議使用 2.6.0 版。如要進一步瞭解 PyTorch 和 PyTorch/XLA 版本,請參閱「PyTorch - Get Started」和「PyTorch/XLA releases」。
如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA 安裝說明。
下載 HuggingFace 存放區,並安裝必要條件。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --worker=all \ --command=" git clone https://github.com/suexu1025/transformers.git vittransformers; \ cd vittransformers; \ pip3 install .; \ pip3 install datasets; \ wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"
訓練模型
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
刪除 TPU 和排入佇列的資源
在工作階段結束時刪除 TPU 和排入佇列的資源。
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
基準結果
下表列出不同加速器類型的基準輸送量。
v5litepod-4 | v5litepod-16 | v5litepod-64 | |
訓練週期 | 3 | 3 | 3 |
全域批次大小 | 32 | 128 | 512 |
處理量 (每秒樣本數) | 201 | 657 | 2,844 |