Cloud TPU v5e training
Cloud TPU v5e is Google Cloud's latest generation AI accelerator. With a smaller, 256-chip footprint per Pod, v5e is optimized to be the highest-value product for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving. For more information about using Cloud TPU v5e for serving, see Inference using v5e.
For more information about Cloud TPU v5e TPU hardware and configurations, see TPU v5e.
Get started
The following sections describe how to get started using TPU v5e.
Request quota
You need quota to use TPU v5e for training. There are different quota types for on-demand TPUs, reserved TPUs, and TPU Spot VMs. There are separate quotas required if you're using your TPU v5e for inference. For more information about quotas, see Quotas. To request TPU v5e quota, contact Cloud Sales.
Create a Google Cloud account and project
You need a Google Cloud account and project to use Cloud TPU. For more information, see Set up a Cloud TPU environment.
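As a quick, hedged sketch of this setup with the gcloud CLI (the project ID here is a placeholder, and billing setup is not shown):
# Authenticate, create a project, and enable the Cloud TPU API.
gcloud auth login
gcloud projects create my-tpu-project
gcloud config set project my-tpu-project
gcloud services enable tpu.googleapis.com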
Create a Cloud TPU
The best practice is to provision Cloud TPU v5es as queued resources using the queued-resource create command. For more information, see Manage queued resources.
You can also use the Create Node API (gcloud compute tpus tpu-vm create) to provision Cloud TPU v5es. For more information, see Manage TPU resources.
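For example, a minimal queued-resource request looks like the following (all values are placeholders; the flags are described in the tutorials later in this document):
gcloud compute tpus queued-resources create my-queued-resource \
--node-id=my-tpu \
--project=my-project \
--zone=us-west4-a \
--accelerator-type=v5litepod-16 \
--runtime-version=v2-alpha-tpuv5-lite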
For more information about available v5e configurations for training, see Cloud TPU v5e types for training.
Framework setup
This section describes the general setup process for custom model training using JAX or PyTorch with TPU v5e. TensorFlow support is available in the tpu-vm-tf-2.17.0-pjrt and tpu-vm-tf-2.17.0-pod-pjrt TPU runtime versions.
For inference setup instructions, see v5e inference introduction.
Setup for JAX
If you have slice shapes greater than 8 chips, you will have multiple VMs in one slice. In this case, you need to use the --worker=all flag to run the installation on all TPU VMs in a single step without using SSH to log into each separately:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Variable | Description |
TPU_NAME | The user-assigned text ID of the TPU which is created when the queued resource request is allocated. |
PROJECT_ID | Google Cloud Project Name. Use an existing project or create a new one at Set up your Google Cloud project |
ZONE | See the TPU regions and zones document for the supported zones. |
worker | The TPU VM that has access to the underlying TPUs. |
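For example, you might set these variables as follows before running the command (all values are placeholders; pick a zone that supports v5e):
export TPU_NAME=my-tpu
export PROJECT_ID=my-project
export ZONE=us-west4-a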
Run the following command to check the number of devices (the outputs shown here were produced with a v5litepod-16 slice). This code tests that everything is installed correctly by checking that JAX sees the Cloud TPU TensorCores and can run basic operations:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'
The output will be similar to the following:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4
jax.device_count() shows the total number of chips in the given slice. jax.local_device_count() indicates the count of chips accessible by a single VM in this slice.
# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'
The output will be similar to the following:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
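Building on the psum example, the following is a hedged sketch of what a minimal data-parallel training step looks like with jax.pmap. The linear model, learning rate, and shapes are toy values for illustration, not part of any tutorial:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='python3 - <<EOF
import jax
import jax.numpy as jnp

# Mean-squared-error loss for a toy linear model.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

# One SGD step per chip; gradients are averaged across chips with pmean.
def step(w, x, y):
    g = jax.grad(loss)(w, x, y)
    return w - 0.1 * jax.lax.pmean(g, "i")

p_step = jax.pmap(step, axis_name="i")

n = jax.local_device_count()
w = jnp.zeros((n, 4))      # one replica of the weights per local chip
x = jnp.ones((n, 8, 4))    # one batch shard per local chip
y = jnp.ones((n, 8))
w = p_step(w, x, y)
print(w.shape)
EOF'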
Try the JAX Tutorials in this document to get started with v5e training using JAX.
Setup for PyTorch
Note that v5e only supports the PJRT runtime, and PyTorch 2.1+ uses PJRT as the default runtime for all TPU versions.
This section describes how to start using PJRT on v5e with PyTorch/XLA with commands for all workers.
Install dependencies
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt-get update -y
sudo apt-get install libomp5 -y
pip3 install mkl mkl-include
pip3 install tf-nightly tb-nightly tbp-nightly
pip3 install numpy
sudo apt-get install libopenblas-dev -y
pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
If you get an error when installing the wheels for torch, torch_xla, or torchvision, such as pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, downgrade setuptools with this command:
pip3 install setuptools==62.1.0
Run a script with PJRT
If LD_PRELOAD is set, unset it before running your script:
unset LD_PRELOAD
The following is an example using a Python script to do a calculation on a v5e VM:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
export PJRT_DEVICE=TPU_C_API
export PT_XLA_DEBUG=0
export USE_TORCH=ON
unset LD_PRELOAD
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'
This generates output similar to the following:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
[-1.0731, 0.3422, 3.1445],
[ 0.5743, 0.2379, 1.1105]], device='xla:0')
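The example above runs on a single XLA device per Python process. As a hedged sketch that is not part of the official steps, you can drive every chip on each VM by launching one process per chip with torch_xla's multiprocessing helper (the /tmp/all_reduce_check.py path is a placeholder; this assumes the torch and torch_xla installation from the previous step):
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU_C_API
cat > /tmp/all_reduce_check.py <<EOF
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    dev = xm.xla_device()              # one chip per process
    t = torch.ones(2, device=dev)
    xm.all_reduce(xm.REDUCE_SUM, [t])  # sum across all chips in the slice
    xm.mark_step()                     # execute the pending lazy graph
    print(index, t.cpu())

if __name__ == "__main__":
    xmp.spawn(_mp_fn)
EOF
python3 /tmp/all_reduce_check.py'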
Try the PyTorch Tutorials in this document to get started with v5e training using PyTorch.
Delete your TPU and queued resource at the end of your session. To delete a queued resource, first delete the TPU (the slice) and then delete the queued resource:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
These two steps can also be used to remove queued resource requests that are in the FAILED state.
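To check the state of your queued-resource requests before deleting them (for example, to find requests in the FAILED state), you can list them:
gcloud compute tpus queued-resources list \
--project=${PROJECT_ID} \
--zone=${ZONE}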
JAX/FLAX examples
The following sections describe examples of how to train JAX and FLAX models on TPU v5e.
Train ImageNet on v5e
This tutorial describes how to train ImageNet on v5e using fake input data. If you want to use real data, refer to the README file on GitHub.
Set up
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install the newest version of JAX and jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Clone the ImageNet model and install the corresponding requirements:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'
To generate fake data, the model needs information about the dimensions of the dataset. This can be gathered from the ImageNet dataset's metadata:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'
Train the model
Once all the previous steps are done, you can train the model.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Hugging Face FLAX Models
Hugging Face models implemented in FLAX work out of the box on Cloud TPU v5e. This section provides instructions for running popular models.
Train ViT on Imagenette
This tutorial shows you how to train the Vision Transformer (ViT) model from HuggingFace using the Fast AI Imagenette dataset on Cloud TPU v5e.
ViT was the first model to successfully train a Transformer encoder on ImageNet with excellent results compared to convolutional networks.
Set up
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install JAX and jaxlib:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Download the Hugging Face repository and install requirements:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.17.0 && pip install -r examples/flax/vision/requirements.txt'
Download the Imagenette dataset:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
Train the model
Train the model with a pre-mapped buffer at 4GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
ViT benchmarking results
The training script was run on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs with different accelerator types.
Accelerator type | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoch | 3 | 3 | 3 |
Global batch size | 32 | 128 | 512 |
Throughput (examples/sec) | 263.40 | 429.34 | 470.71 |
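The global batch sizes in this table follow from the per-device batch size of 8 set in the training command multiplied by the number of chips in the slice; a quick sketch of the arithmetic:
python3 - <<EOF
# Global batch size = per-device batch size x chips in the slice.
per_device = 8                 # --per_device_train_batch_size above
for chips in (4, 16, 64):      # v5litepod-4 / -16 / -64
    print(f"v5litepod-{chips}: global batch = {per_device * chips}")
EOF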
Train Diffusion on Pokémon
This tutorial shows you how to train the Stable Diffusion model from HuggingFace using the Pokémon dataset on Cloud TPU v5e.
The Stable Diffusion model is a latent text-to-image model that generates photo-realistic images from any text input.
Set up
Set up a storage bucket for your model output.
gcloud storage buckets create gs://your_bucket \
--project=your_project \
--location=us-west1
Create environment variables:
export GCS_BUCKET_NAME=your_bucket
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west1-c
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Command flag descriptions
Variable | Description |
GCS_BUCKET_NAME | Displayed in the Google Cloud console -> Cloud Storage -> Buckets |
PROJECT_ID | Google Cloud Project Name. Use an existing project or create a new one at Set up your Google Cloud project |
ACCELERATOR_TYPE | See the TPU versions page for your TPU version. |
ZONE | See the TPU regions and zones document for the supported zones. |
RUNTIME_VERSION | Use v2-alpha-tpuv5-lite for the RUNTIME_VERSION. |
SERVICE_ACCOUNT | The email address of your service account, which you can find in the Google Cloud console -> IAM -> Service Accounts. For example: tpu-service-account@myprojectID.iam.gserviceaccount.com |
TPU_NAME | The user-assigned text ID of the TPU which is created when the queued resource request is allocated. |
QUEUED_RESOURCE_ID | The user-assigned text ID of the queued resource request. See the Queued Resources document for information about queued resources. |
QUOTA_TYPE | Can be reserved or spot. If neither is specified, QUOTA_TYPE defaults to on-demand. See quotas for information about the different types of quotas supported by Cloud TPU. |
VALID_UNTIL_DURATION | The duration for which the request is valid. See Queued resources for information about the different valid durations. |
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install JAX and jaxlib.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Download the Hugging Face repository and install requirements.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install tensorflow==2.17.0 clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
Train the model
Train the model with a pre-mapped buffer at 4GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="
git clone https://github.com/google/maxdiffusion
cd maxdiffusion
git reset --hard 57629bcf4fa32fe5a57096b60b09f41f2fa5c35d # This identifies the GitHub commit to use.
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip3 install -r requirements.txt
pip3 install .
export LIBTPU_INIT_ARGS=""
python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml run_name=your_run base_output_directory=gs://${GCS_BUCKET_NAME}/ enable_profiler=False"
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmarking results for diffusion
The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs.
Accelerator type | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Train Step | 1500 | 1500 | 1500 |
Global batch size | 32 | 64 | 128 |
Throughput (examples/sec) | 36.53 | 43.71 | 49.36 |
Train GPT2 on the OSCAR dataset
This tutorial shows you how to train the GPT2 model from HuggingFace using the OSCAR dataset on Cloud TPU v5e.
GPT2 is a transformer model pre-trained on raw text without human labeling. It was trained to predict the next word in a sentence.
Set up
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install JAX and jaxlib.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Download the Hugging Face repository and install requirements.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow && pip install -r examples/flax/language-modeling/requirements.txt'
Download configs to train the model.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers/examples/flax/language-modeling && gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/v5litepod-preview/jax/gpt . --recursive'
Train the model
Train the model with a pre-mapped buffer at 4GB.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='cd transformers/examples/flax/language-modeling && TPU_PREMAPPED_BUFFER_SIZE=4294967296 JAX_PLATFORMS=tpu python3 run_clm_flax.py --output_dir=./gpt --model_type=gpt2 --config_name=./gpt --tokenizer_name=./gpt --dataset_name=oscar --dataset_config_name=unshuffled_deduplicated_no --do_train --do_eval --block_size=512 --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --learning_rate=5e-3 --warmup_steps=1000 --adam_beta1=0.9 --adam_beta2=0.98 --weight_decay=0.01 --overwrite_output_dir --num_train_epochs=3 --logging_steps=500 --eval_steps=2500'
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmarking results for GPT2
The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs.
Accelerator type | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoch | 3 | 3 | 3 |
Global batch size | 64 | 64 | 64 |
Throughput (examples/sec) | 74.60 | 72.97 | 72.62 |
PyTorch/XLA
The following sections describe examples of how to train PyTorch/XLA models on TPU v5e.
Train ResNet using the PJRT runtime
PyTorch/XLA migrated from XRT to PJRT starting with PyTorch 2.0. These are the updated instructions to set up v5e for PyTorch/XLA training workloads.
Set up
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=tpu-name
export QUEUED_RESOURCE_ID=queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install Torch/XLA-specific dependencies:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt-get update -y
sudo apt-get install libomp5 -y
pip3 install mkl mkl-include
pip3 install tf-nightly tb-nightly tbp-nightly
pip3 install numpy
sudo apt-get install libopenblas-dev -y
pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
Train the ResNet model
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
date
export PJRT_DEVICE=TPU_C_API
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export XLA_USE_BF16=1
export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
git clone https://github.com/pytorch/xla.git
cd xla/
git reset --hard caf5168785c081cd7eb60b49fe4fffeb894c39d9
python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 --num_workers=16 --log_steps=300 --batch_size=64 --profile'
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark result
The following table shows the benchmark throughputs.
Accelerator type | Throughput (examples/sec) |
v5litepod-4 | 4,240 |
v5litepod-16 | 10,810 |
v5litepod-64 | 46,154 |
Train GPT2 on v5e
This tutorial covers how to train GPT2 on v5e with PyTorch/XLA, using the Hugging Face repository and the wikitext dataset.
Set up
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install PyTorch/XLA dependencies.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt-get -y update
sudo apt install -y libopenblas-base
pip3 install torchvision
pip3 uninstall -y torch
pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
Download the Hugging Face repository and install requirements.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
git clone https://github.com/pytorch/xla.git
pip install --upgrade accelerate
git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout ebdb185befaa821304d461ed6aa20a17e4dc3aa2
pip install .
git log -1
pip install datasets evaluate scikit-learn'
Download configs of the pre-trained model.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/my_config_2.json transformers/examples/pytorch/language-modeling/ --recursive
gcloud storage cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/fsdp_config.json transformers/examples/pytorch/language-modeling/'
Train the model
Train the 2B model using a batch size of 16.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU_C_API
cd transformers/
export LD_LIBRARY_PATH=/usr/local/lib/
export PT_XLA_DEBUG=0
export USE_TORCH=ON
python3 examples/pytorch/xla_spawn.py \
--num_cores=4 \
examples/pytorch/language-modeling/run_clm.py \
--num_train_epochs=3 \
--dataset_name=wikitext \
--dataset_config_name=wikitext-2-raw-v1 \
--per_device_train_batch_size=16 \
--per_device_eval_batch_size=16 \
--do_train \
--do_eval \
--logging_dir=./tensorboard-metrics \
--cache_dir=./cache_dir \
--output_dir=/tmp/test-clm \
--overwrite_output_dir \
--config_name=examples/pytorch/language-modeling/my_config_2.json \
--tokenizer_name=gpt2 \
--block_size=1024 \
--optim=adafactor \
--adafactor=true \
--save_strategy=no \
--logging_strategy=no \
--fsdp=full_shard \
--fsdp_config=examples/pytorch/language-modeling/fsdp_config.json'
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark result
The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the benchmark throughputs for different accelerator types.
Accelerator type | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoch | 3 | 3 | 3 |
config | 600M | 2B | 16B |
Global batch size | 64 | 128 | 256 |
Throughput (examples/sec) | 66 | 77 | 31 |
Train ViT on v5e
This tutorial covers how to train ViT on v5e with PyTorch/XLA, using the Hugging Face repository and the CIFAR-10 dataset.
Set up
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=tpu-name
export QUEUED_RESOURCE_ID=queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--valid-until-duration=${VALID_UNTIL_DURATION} \
--service-account=${SERVICE_ACCOUNT} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
When the queued resource is in the ACTIVE state, the output will be similar to the following:
state: ACTIVE
Install PyTorch/XLA dependencies.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
sudo apt-get update -y
sudo apt-get install libomp5 -y
pip3 install mkl mkl-include
pip3 install tf-nightly tb-nightly tbp-nightly
pip3 install numpy
sudo apt-get install libopenblas-dev -y
pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
Download the Hugging Face repository and install requirements.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command="
git clone https://github.com/suexu1025/transformers.git vittransformers; \
cd vittransformers; \
pip3 install .; \
pip3 install datasets; \
wget https://raw.githubusercontent.com/pytorch/xla/master/scripts/capture_profile.py"
Train the model
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='
export PJRT_DEVICE=TPU_C_API
export PT_XLA_DEBUG=0
export USE_TORCH=ON
export TF_CPP_MIN_LOG_LEVEL=0
export XLA_USE_BF16=1
export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
cd vittransformers
python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
--remove_unused_columns=False \
--label_names=pixel_values \
--mask_ratio=0.75 \
--norm_pix_loss=True \
--do_train=true \
--do_eval=true \
--base_learning_rate=1.5e-4 \
--lr_scheduler_type=cosine \
--weight_decay=0.05 \
--num_train_epochs=3 \
--warmup_ratio=0.05 \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--logging_strategy=steps \
--logging_steps=30 \
--evaluation_strategy=epoch \
--save_strategy=epoch \
--load_best_model_at_end=True \
--save_total_limit=3 \
--seed=1337 \
--output_dir=MAE \
--overwrite_output_dir=true \
--logging_dir=./tensorboard-metrics \
--tpu_metrics_debug=true'
Delete the TPU and queued resource
Delete your TPU and queued resource at the end of your session.
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Benchmark result
The following table shows the benchmark throughputs for different accelerator types.
Accelerator type | v5litepod-4 | v5litepod-16 | v5litepod-64 |
Epoch | 3 | 3 | 3 |
Global batch size | 32 | 128 | 512 |
Throughput (examples/sec) | 201 | 657 | 2,844 |
TensorFlow 2.x
The following sections describe examples of how to train TensorFlow 2.x models on TPU v5e.
Train ResNet on a single-host v5e
This tutorial describes how to train ImageNet on v5litepod-4 or v5litepod-8 using a fake dataset. If you want to use a different dataset, refer to Preparing the dataset.
Set up
Create environment variables:
export PROJECT_ID=your-project-ID
export ACCELERATOR_TYPE=v5litepod-4
export ZONE=us-east1-c
export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pjrt
export TPU_NAME=your-tpu-name
export QUEUED_RESOURCE_ID=your-queued-resource-id
export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE can be either v5litepod-4 or v5litepod-8.
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state. To check the state of your queued resource, use the following command:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
Connect to your TPU using SSH:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE}
Set some environment variables:
export MODELS_REPO=/usr/share/tpu/models
export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
export MODEL_DIR=gcp-directory-to-store-model
export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
Change to the models repository directory and install requirements:
cd ${MODELS_REPO} && git checkout r2.15.0
pip install -r official/requirements.txt
Train the model
Run the training script.
python3 official/vision/train.py \
--tpu=local \
--experiment=resnet_imagenet \
--mode=train_and_eval \
--config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
--model_dir=${MODEL_DIR} \
--params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"
Delete the TPU and queued resource
Delete your TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Delete your queued resource request:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Train ResNet on a multi-host v5e
This tutorial describes how to train ImageNet on v5litepod-16 or larger using a fake dataset. If you want to use a different dataset, see Preparing the dataset.
Create environment variables:
export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-east1-c
export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pod-pjrt
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your-queued-resource-id
export QUOTA_TYPE=quota-type
ACCELERATOR_TYPE can be v5litepod-16 or larger.
Create a queued resource:
gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
--node-id=${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--accelerator-type=${ACCELERATOR_TYPE} \
--runtime-version=${RUNTIME_VERSION} \
--${QUOTA_TYPE}
You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state. To check the state of your queued resource, use the following command:
gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE}
Connect to your TPU (worker zero) using SSH:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE}
Set some environment variables:
export MODELS_REPO=/usr/share/tpu/models
export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
export MODEL_DIR=gcp-directory-to-store-model
export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
export TPU_LOAD_LIBRARY=0
export TPU_NAME=your_tpu_name
Change to the models repository directory and install requirements:
cd $MODELS_REPO && git checkout r2.15.0
pip install -r official/requirements.txt
Train the model
Run the training script.
python3 official/vision/train.py \
--tpu=${TPU_NAME} \
--experiment=resnet_imagenet \
--mode=train_and_eval \
--model_dir=${MODEL_DIR} \
--params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*"
Delete the TPU and queued resource
Delete your TPU:
gcloud compute tpus tpu-vm delete ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet
Delete your queued resource request:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--quiet