Cloud TPU v5e training

Cloud TPU v5e is Google Cloud's latest-generation AI accelerator. With a smaller 256-chip footprint per Pod, v5e is optimized to deliver the highest value for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving.

Cloud TPU v5e concepts, system architecture, and configurations

If you are new to Cloud TPUs, check out the TPU documentation home.

General Cloud TPU concepts (for example, slices, hosts, chips, and TensorCores), and Cloud TPU system architecture are described in the Cloud TPU System Architecture page.

Each Cloud TPU version requires specific accelerator types for training and inference. These accelerator types are described in the v5e configurations section.

Inference

Inference is the process of using a trained model to make predictions on new data; it is the core of the serving process.

Slices

A slice represents a collection of chips, all located inside the same Pod, connected by high-speed inter-chip interconnect (ICI). v5e slices have 2D shapes. See the table in the v5e configurations section for supported slice shapes.

The terms chip shape and chip topology are also used to refer to slice shapes.

Serving

Serving is the process of deploying a trained machine learning model to a production environment where it can be used to make predictions or decisions. Latency and service-level availability are important for serving.

Single host versus multi host

A host is a physical computer (CPU) that runs VMs. A host can run multiple VMs at once.

Slices with 8 or fewer chips use at most one host. Slices with more than 8 chips span multiple hosts and can run distributed training using multiple hosts. See the TPU System Architecture page for details on slices and chips.

v5e supports multi-host training and multi-host inference (using SAX).

TPU VM

A virtual machine running Linux that has access to the underlying TPUs. For v5e TPUs, each TPU VM has direct access to 1, 4, or 8 chips depending on the user-specified accelerator type. A TPU VM is also known as a worker.

Worker

See TPU VM.
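
Once you have created a slice (see the sections that follow), you can check how many workers it has by describing the TPU. A rough sketch; the assumption that the describe output's networkEndpoints field lists one entry per worker is not confirmed by this document:

# Show the network endpoints of the slice; each entry is assumed to
# correspond to one worker (TPU VM).
gcloud compute tpus tpu-vm describe ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --format='yaml(networkEndpoints)'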

Get started

For information about v5e TPU hardware, see System Architecture.

Securing capacity

Contact Cloud Sales to start using Cloud TPU v5e for your AI workloads.

Prepare a Google Cloud Project

  1. Sign in to your Google Account. If you haven't already, sign up for a new account.
  2. In the Google Cloud console, select or create a Google Cloud project from the project selector page.
  3. Billing setup is required for all Google Cloud usage, so make sure billing is enabled for your project.

  4. Install gcloud alpha components.

  5. If you are a TPU user reusing existing gcloud alpha components, update these to ensure that relevant commands and flags are supported:

    gcloud components update
    
  6. Enable the TPU API through the following gcloud command in the Cloud Shell. (You may also enable it from the Google Cloud console.)

    gcloud services enable tpu.googleapis.com
    
  7. Enable the TPU service account.

    Service accounts allow the Cloud TPU service to access other Google Cloud services. Using a user-managed service account is a recommended Google Cloud practice. Create a service account and grant it the following roles; a gcloud sketch for granting them appears after this numbered list:

    • TPU Admin
    • Storage Admin: Needed for accessing Cloud Storage
    • Logs Writer: Needed for writing logs with the Logging API
    • Monitoring Metric Writer: Needed for writing metrics to Cloud Monitoring
  8. Configure the project and zone.

    Your project ID is the name of your project shown on the Cloud console.

    export PROJECT_ID=your-project-id
    export ZONE=us-west4-a
    
    gcloud compute tpus tpu-vm service-identity create --zone=${ZONE}
    
    gcloud auth login
    gcloud config set project ${PROJECT_ID}
    gcloud config set compute/zone ${ZONE}
    

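If you created a user-managed service account in step 7, you can grant it the required roles with gcloud. A minimal sketch, assuming an example service account address of the form shown in the variable table below; repeat the command for each role ID (roles/tpu.admin, roles/storage.admin, roles/logging.logWriter, roles/monitoring.metricWriter):

# Grant one of the required roles to the service account; repeat for
# each role listed in step 7.
gcloud projects add-iam-policy-binding ${PROJECT_ID} \
   --member="serviceAccount:tpu-service-account@${PROJECT_ID}.iam.gserviceaccount.com" \
   --role="roles/tpu.admin"
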
Provision the Cloud TPU environment

The best practice is to provision Cloud TPU v5e slices as queued resources, using the queued-resource create command. However, you can also use the Create Node API (gcloud alpha compute tpus tpu-vm create) to provision them, as sketched below.
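
For reference, a minimal sketch of the Create Node path (using the environment variables defined in the next section; adjust the flags to your environment):

# Provision a v5e slice directly, without the queued resource workflow.
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --accelerator-type=${ACCELERATOR_TYPE} \
   --version=${RUNTIME_VERSION}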

Create environment variables

Set necessary environment variables for TPU creation.

Replace the variables in the following list with the values you will use for your training or inference job.

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
export SERVICE_ACCOUNT=your_service_account
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id
export QUOTA_TYPE=quota_type
export VALID_UNTIL_DURATION=1d
Variable             Description
PROJECT_ID           Your Google Cloud project name.
ACCELERATOR_TYPE     See the Accelerator Types section for supported accelerator types.
ZONE                 All capacity is in us-west4-a.
RUNTIME_VERSION      Use v2-alpha-tpuv5-lite.
SERVICE_ACCOUNT      The email address of your service account. You can find it in the Google Cloud console -> IAM -> Service Accounts. For example: tpu-service-account@myprojectID.iam.gserviceaccount.com
TPU_NAME             The user-assigned text ID of the TPU, created when the queued resource request is allocated.
QUEUED_RESOURCE_ID   The user-assigned text ID of the queued resource request. See the queued resources document for information about queued resources.
QUOTA_TYPE           Can be reserved or best-effort. If neither is specified, QUOTA_TYPE defaults to on-demand. See quotas for information about the quota types supported by Cloud TPU.
VALID_UNTIL_DURATION The duration for which the request is valid. See queued resources for information about the different valid durations.
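
Because the create command passes this variable as --${QUOTA_TYPE}, the value expands into a gcloud flag. A small sketch of what that expansion looks like for a reserved request (best-effort works the same way):

# QUOTA_TYPE becomes a flag on the create command below: here --reserved.
export QUOTA_TYPE=reserved
echo "--${QUOTA_TYPE}"   # prints --reserved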

Create a TPU resource

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
   --node-id=${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --accelerator-type=${ACCELERATOR_TYPE} \
   --runtime-version=${RUNTIME_VERSION} \
   --valid-until-duration=${VALID_UNTIL_DURATION} \
   --service-account=${SERVICE_ACCOUNT} \
   --${QUOTA_TYPE}

If the queued resource is created successfully, the state within the response field will be either WAITING_FOR_RESOURCES or FAILED. WAITING_FOR_RESOURCES means the queued resource has passed preliminary validation and is awaiting capacity; once capacity is available, the request transitions to PROVISIONING. WAITING_FOR_RESOURCES does not guarantee that quota will be allocated, and the transition from WAITING_FOR_RESOURCES to ACTIVE can take some time. If the queued resource is in the FAILED state, the failure reason is included in the output. A request that is not filled within the --valid-until-duration expires and moves to the FAILED state.

You will be able to access your TPU VM using SSH once your queued resource is in the ACTIVE state.

Use the gcloud compute tpus queued-resources list or describe commands to query the status of your queued resource.

gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE}
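
If you want to block until the queued resource is ready, you can poll its state in a shell loop. A minimal sketch; the state.state field path in --format is an assumption about the describe output:

# Poll until the queued resource reaches a terminal state.
while true; do
  STATE=$(gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
     --project=${PROJECT_ID} \
     --zone=${ZONE} \
     --format='value(state.state)')
  echo "Queued resource state: ${STATE}"
  [[ "${STATE}" == "ACTIVE" || "${STATE}" == "FAILED" ]] && break
  sleep 60
done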

The state represents the status of a queued resource. The states are defined as:

State                  Description
WAITING_FOR_RESOURCES  The create request has been received; provisioning begins as soon as capacity is available.
PROVISIONING           The TPU slices are being provisioned.
ACTIVE                 All TPUs are provisioned and ready to use. If a startup script is provided, it starts executing on all TPUs when the queued resource state transitions to ACTIVE.
FAILED                 The slices could not be provisioned.
SUSPENDING             One or more slices are being deleted.
SUSPENDED              All underlying slices are deleted, but the queued resource stays intact until explicitly deleted. A suspended queued resource currently cannot be resumed and should be deleted.
DELETING               The queued resource is being deleted.

Connect to the TPU VM using SSH

The following section describes how you can install binaries on each TPU VM in your TPU slice and run code. In this context, a TPU VM is also known as a worker.

See the VM Types section to determine how many VMs your slice will have.

To install the binaries or run code, connect to your TPU VM using the tpu-vm ssh command.

gcloud compute tpus tpu-vm ssh ${TPU_NAME}

To access a specific TPU VM with SSH, use the --worker flag, which takes a 0-based index:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} --worker=1

If your slice has more than 8 chips, you will have multiple VMs in one slice. In this case, use the --worker=all flag to run the installation on all TPU VMs without having to connect to each one separately. For example:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
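
The same --worker flag also works for copying files to the slice with the tpu-vm scp command. A short sketch, assuming a local script named train.py (a hypothetical file name):

# Copy a local script to the home directory of every worker in the slice.
gcloud compute tpus tpu-vm scp train.py ${TPU_NAME}: \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all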

Manage

All commands you can use to manage your TPU VMs are described in Managing TPUs.

Framework Setup

This section describes the general setup process for custom model training using JAX or PyTorch with TPU v5e. TensorFlow support is available in the tpu-vm-tf-2.15.0-pjrt and tpu-vm-tf-2.15.0-pod-pjrt TPU runtime versions.

For inference setup instructions, see v5e inference introduction.

Setup for JAX

If your slice has more than 8 chips, you will have multiple VMs in one slice. In this case, use the --worker=all flag to run the installation on all TPU VMs in a single step, without using SSH to log in to each one separately:

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Run the following command to check the number of devices (the outputs shown here were produced with a v5litepod-16 slice). This code verifies that everything is installed correctly by checking that JAX sees the Cloud TPU TensorCores and can run basic operations:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() shows the total number of chips in the given slice. jax.local_device_count() indicates the count of chips accessible by a single VM in this slice.

# Check the total number of chips in the slice by summing a vector of
# ones across all devices with a pmap'd psum; each element of the result
# equals the total chip count.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Try the JAX Tutorials in this document to get started with v5e training using JAX.

Setup for PyTorch

Note that v5e only supports the PJRT runtime; PyTorch 2.1+ uses PJRT as the default runtime for all TPU versions.

This section describes how to start using PJRT on v5e with PyTorch/XLA with commands for all workers.

Install dependencies

gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
      pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'

If installing the torch, torch_xla, or torchvision wheels fails with an error like pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, downgrade setuptools with this command:

pip3 install setuptools==62.1.0

To run a script with PJRT, first unset LD_PRELOAD:

unset LD_PRELOAD

The following is an example using a Python script to do a calculation on a v5e VM:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU_C_API
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

This generates output similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
        [-1.0731,  0.3422,  3.1445],
        [ 0.5743,  0.2379,  1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
        [-1.0731,  0.3422,  3.1445],
        [ 0.5743,  0.2379,  1.1105]], device='xla:0')

Try the PyTorch Tutorials in this document to get started with v5e training using PyTorch.

Delete your TPU and queued resource at the end of your session. To delete a queued resource, first delete the TPU slice, then delete the queued resource:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

These two steps can also be used to remove queued resource requests that are in the FAILED state.
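
To find queued resource requests that need cleanup, you can list them and filter by state. A sketch; the state.state filter path is an assumption about the list output:

# List queued resources in the FAILED state in this project and zone.
gcloud compute tpus queued-resources list \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --filter='state.state=FAILED'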

Monitor and Profile

Cloud TPU v5e supports monitoring and profiling using the same methods as previous generations of Cloud TPU. You can read Profile your model with Cloud TPU tools to learn more about profiling and Monitoring Cloud TPU VMs to learn more about monitoring.
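
As a quick way to produce a profile you can open in TensorBoard, JAX's programmatic profiler API can capture a trace around a computation. A minimal sketch; the log directory /tmp/tensorboard and the matrix size are arbitrary examples:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='python3 -c "
import jax
# Capture a trace of a small matmul; inspect it with the TensorBoard profile plugin.
jax.profiler.start_trace(\"/tmp/tensorboard\")
x = jax.numpy.ones((1024, 1024))
(x @ x).block_until_ready()
jax.profiler.stop_trace()"'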

JAX/FLAX examples

Train ImageNet on v5e

This tutorial describes how to train ImageNet on v5e using fake input data. If you want to use real data, refer to the README file on GitHub.

Set up

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Install the newest version of JAX and jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Clone the ImageNet model and install the corresponding requirements:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/google/flax.git && cd flax/examples/imagenet && pip install -r requirements.txt && pip install flax==0.7.4'
    
  5. To generate fake data, the model needs information on the dimensions of the dataset. This can be gathered from the ImageNet dataset's metadata:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='mkdir -p $HOME/flax/.tfds/metadata/imagenet2012/5.1.0 && curl https://raw.githubusercontent.com/tensorflow/datasets/v4.4.0/tensorflow_datasets/testing/metadata/imagenet2012/5.1.0/dataset_info.json --output $HOME/flax/.tfds/metadata/imagenet2012/5.1.0/dataset_info.json'
    

Train the model

Once all the previous steps are done, you can train the model.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd flax/examples/imagenet && JAX_PLATFORMS=tpu python3 imagenet_fake_data_benchmark.py'

Delete the TPU and queued resource

Delete your TPU and queued resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Hugging Face FLAX Models

Hugging Face models implemented in FLAX work out of the box on Cloud TPU v5e. This section provides instructions for running popular models.

Train ViT on Imagenette

This tutorial shows you how to train the Vision Transformer (ViT) model from HuggingFace using the Fast AI imagenette dataset on Cloud TPU v5e.

The ViT model was the first to successfully train a Transformer encoder on ImageNet, with excellent results compared to convolutional networks.

Set up

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your_queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Install JAX and jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Download the Hugging Face repository and install requirements:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.16.1 && pip install -r examples/flax/vision/requirements.txt'
    
  5. Download the Imagenette dataset:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'
    

Train the model

Train the model with a pre-mapped buffer at 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'
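
The pre-mapped buffer size is controlled with the TPU_PREMAPPED_BUFFER_SIZE environment variable, specified in bytes; the GPT2 tutorial later in this document sets it explicitly on the training command. A sketch of setting it before launching the script above, where 4294967296 bytes = 4 GB:

# Set a 4 GB pre-mapped buffer before launching the training script.
export TPU_PREMAPPED_BUFFER_SIZE=4294967296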

Delete the TPU and queued-resource

Delete your TPU and queued-resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

ViT benchmarking results

The training script was run on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs with different accelerator types.

Accelerator type           v5litepod-4  v5litepod-16  v5litepod-64
Epoch                      3            3             3
Global batch size          32           128           512
Throughput (examples/sec)  263.40       429.34        470.71

Train Diffusion on Pokémon

This tutorial shows you how to train the Stable Diffusion model from HuggingFace using the Pokémon dataset on Cloud TPU v5e.

The Stable Diffusion model is a latent text-to-image model that generates photo-realistic images from any text input.

Set up

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Install JAX and jaxlib.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Download the HuggingFace repository and install requirements.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install tensorflow==2.16.1 clu && pip install -U -r examples/text_to_image/requirements_flax.txt'
    

Train the model

Train the model with a pre-mapped buffer at 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd diffusers/examples/text_to_image && JAX_PLATFORMS=tpu,cpu python3 train_text_to_image_flax.py --pretrained_model_name_or_path=duongna/stable-diffusion-v1-4-flax --dataset_name=lambdalabs/pokemon-blip-captions --resolution=128 --center_crop --random_flip --train_batch_size=4 --mixed_precision=fp16 --max_train_steps=1500 --learning_rate=1e-05 --max_grad_norm=1 --output_dir=sd-pokemon-model'

Delete the TPU and queued resources

Delete your TPU and queued resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmarking results for diffusion

The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs.

Accelerator type           v5litepod-4  v5litepod-16  v5litepod-64
Train step                 1500         1500          1500
Global batch size          32           64            128
Throughput (examples/sec)  36.53        43.71         49.36

Train GPT2 on the OSCAR dataset

This tutorial shows you how to train the GPT2 model from HuggingFace using the OSCAR dataset on Cloud TPU v5e.

GPT2 is a transformer model pretrained on raw text without human labeling. It was trained to predict the next word in sentences.

Set up

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Install JAX and jaxlib.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Download the HuggingFace repository and install requirements.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow && pip install -r examples/flax/language-modeling/requirements.txt'
    
  5. Download configs to train the model.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='cd transformers/examples/flax/language-modeling && gsutil cp -r gs://cloud-tpu-tpuvm-artifacts/v5litepod-preview/jax/gpt .'
    

Train the model

Train the model with a pre-mapped buffer at 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='cd transformers/examples/flax/language-modeling && TPU_PREMAPPED_BUFFER_SIZE=4294967296 JAX_PLATFORMS=tpu python3 run_clm_flax.py --output_dir=./gpt --model_type=gpt2 --config_name=./gpt --tokenizer_name=./gpt --dataset_name=oscar --dataset_config_name=unshuffled_deduplicated_no --do_train --do_eval --block_size=512 --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --learning_rate=5e-3 --warmup_steps=1000 --adam_beta1=0.9 --adam_beta2=0.98 --weight_decay=0.01 --overwrite_output_dir --num_train_epochs=3 --logging_steps=500 --eval_steps=2500'

Delete the TPU and queued-resource

Delete your TPU and queued-resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmarking results for GPT2

The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs.

Accelerator type           v5litepod-4  v5litepod-16  v5litepod-64
Epoch                      3            3             3
Global batch size          64           64            64
Throughput (examples/sec)  74.60        72.97         72.62

PyTorch/XLA

Train ResNet using the PJRT runtime

PyTorch/XLA migrated from XRT to PJRT in PyTorch 2.0+. The following are the updated instructions to set up v5e for PyTorch/XLA training workloads.

Setup

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=tpu-name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Install Torch/XLA-specific dependencies:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          sudo apt-get update -y
          sudo apt-get install libomp5 -y
          pip3 install mkl mkl-include
          pip3 install tf-nightly tb-nightly tbp-nightly
          pip3 install numpy
          sudo apt-get install libopenblas-dev -y
          pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
          pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
    
Train the ResNet model

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      date
      export PJRT_DEVICE=TPU_C_API
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git reset --hard caf5168785c081cd7eb60b49fe4fffeb894c39d9
      python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 --num_workers=16 --log_steps=300 --batch_size=64 --profile'

Delete the TPU and queued-resource

Delete your TPU and queued-resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark Result

The following table shows the benchmark throughputs.

Accelerator type  Throughput (examples/sec)
v5litepod-4       4,240
v5litepod-16      10,810
v5litepod-64      46,154

Train GPT2 on v5e

This tutorial covers how to train GPT2 on v5e with PyTorch/XLA, using the HuggingFace repository and the wikitext dataset.

Setup

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
    
  3. Install torch/xla dependencies.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          sudo apt-get -y update
          sudo apt install -y libopenblas-base
          pip3 install torchvision
          pip3 uninstall -y torch
          pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
          pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
    
  4. Download the HuggingFace repository and install requirements.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          git clone https://github.com/pytorch/xla.git
          pip install --upgrade accelerate
          git clone https://github.com/huggingface/transformers.git
          cd transformers
          git checkout ebdb185befaa821304d461ed6aa20a17e4dc3aa2
          pip install .
          git log -1
          pip install datasets evaluate scikit-learn
          '
    
  5. Download configs of the pre-trained model.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          gsutil cp -r gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/my_config_2.json transformers/examples/pytorch/language-modeling/
          gsutil cp gs://cloud-tpu-tpuvm-artifacts/config/xl-ml-test/pytorch/gpt2/fsdp_config.json transformers/examples/pytorch/language-modeling/'
    

Train the model

Train the 2B model using a batch size of 16.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU_C_API
      cd transformers/
      export LD_LIBRARY_PATH=/usr/local/lib/
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      python3 examples/pytorch/xla_spawn.py \
         --num_cores=4 \
         examples/pytorch/language-modeling/run_clm.py \
         --num_train_epochs=3 \
         --dataset_name=wikitext \
         --dataset_config_name=wikitext-2-raw-v1 \
         --per_device_train_batch_size=16 \
         --per_device_eval_batch_size=16 \
         --do_train \
         --do_eval \
         --logging_dir=./tensorboard-metrics \
         --cache_dir=./cache_dir \
         --output_dir=/tmp/test-clm \
         --overwrite_output_dir \
         --cache_dir=/tmp \
         --config_name=examples/pytorch/language-modeling/my_config_2.json \
         --tokenizer_name=gpt2 \
         --block_size=1024 \
         --optim=adafactor \
         --adafactor=true \
         --save_strategy=no \
         --logging_strategy=no \
         --fsdp=full_shard \
         --fsdp_config=examples/pytorch/language-modeling/fsdp_config.json'

Delete the TPU and queued-resource

Delete your TPU and queued-resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark Result

The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the benchmark throughputs for different accelerator types.

Accelerator type           v5litepod-4  v5litepod-16  v5litepod-64
Epoch                      3            3             3
Config                     600M         2B            16B
Global batch size          64           128           256
Throughput (examples/sec)  66           77            31

Train ViT on v5e

This tutorial covers how to train ViT on v5e with PyTorch/XLA, using the HuggingFace repository and the cifar10 dataset.

Setup

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-west4-a
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your_service_account
    export TPU_NAME=tpu-name
    export QUEUED_RESOURCE_ID=queued_resource_id
    export QUOTA_TYPE=quota_type
    export VALID_UNTIL_DURATION=1d
    
  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --valid-until-duration=${VALID_UNTIL_DURATION} \
       --service-account=${SERVICE_ACCOUNT} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

     gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

     state: ACTIVE
    
  3. Install torch/xla dependencies.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command='
          sudo apt-get update -y
          sudo apt-get install libomp5 -y
          pip3 install mkl mkl-include
          pip3 install tf-nightly tb-nightly tbp-nightly
          pip3 install numpy
          sudo apt-get install libopenblas-dev -y
          pip3 install torch~=2.1.0 torchvision torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
          pip3 install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html'
    
  4. Download the HuggingFace repository and install requirements.

       gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --worker=all \
       --command="
          git clone https://github.com/suexu1025/transformers.git vittransformers; \
          cd vittransformers; \
          pip3 install .; \
          pip3 install datasets; \
           wget https://raw.githubusercontent.com/pytorch/xla/master/scripts/capture_profile.py"
    

Train the model

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --worker=all \
   --command='
      export PJRT_DEVICE=TPU_C_API
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

Delete the TPU and queued-resource

Delete your TPU and queued-resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} \
   --zone=${ZONE} \
   --quiet

Benchmark Result

The following table shows the benchmark throughputs for different accelerator types.

Accelerator type           v5litepod-4  v5litepod-16  v5litepod-64
Epoch                      3            3             3
Global batch size          32           128           512
Throughput (examples/sec)  201          657           2,844

TensorFlow 2.x

Train ResNet on a single-host v5e

This tutorial describes how to train ImageNet on v5litepod-4 or v5litepod-8 using a fake dataset. If you want to use a different dataset, refer to Preparing the dataset.

Set up

  1. Create environment variables:

    export PROJECT_ID=your-project-ID
    export ACCELERATOR_TYPE=v5litepod-4
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pjrt
    export TPU_NAME=your-tpu-name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    ACCELERATOR_TYPE can be either v5litepod-4 or v5litepod-8.

  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state. To check the state of your queued resource, use the following command:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  3. Connect to your TPU using SSH

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  4. Set some environment variables

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export NEXT_PLUGGABLE_DEVICE_USE_C_API=true
    export TF_PLUGGABLE_DEVICE_LIBRARY_PATH=/lib/libtpu.so
    
  5. Change to the models repository directory and install requirements.

    cd ${MODELS_REPO} && git checkout r2.15.0
    pip install -r official/requirements.txt
    

Train the model

Run the training script.

python3 official/vision/train.py \
   --tpu=local \
   --experiment=resnet_imagenet \
   --mode=train_and_eval \
   --config_file=official/vision/configs/experiments/image_classification/imagenet_resnet50_tpu.yaml \
   --model_dir=${MODEL_DIR} \
   --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*,task.validation_data.input_path=${DATA_DIR}/validation*,task.train_data.global_batch_size=2048,task.validation_data.global_batch_size=2048,trainer.train_steps=100"

Delete the TPU and queued-resource

  1. Delete your TPU

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet
    
  2. Delete your queued resource request

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet
    

Train ResNet on a multi-host v5e

This tutorial describes how to train ImageNet on v5litepod-16 or larger using a fake dataset. If you want to use a different dataset, see Preparing the dataset.

  1. Create environment variables:

    export PROJECT_ID=your_project_ID
    export ACCELERATOR_TYPE=v5litepod-16
    export ZONE=us-east1-c
    export RUNTIME_VERSION=tpu-vm-tf-2.15.0-pod-pjrt
    export TPU_NAME=your_tpu_name
    export QUEUED_RESOURCE_ID=your-queued-resource-id
    export QUOTA_TYPE=quota-type
    

    ACCELERATOR_TYPE must be v5litepod-16 or larger.

  2. Create a TPU resource:

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
       --node-id=${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --runtime-version=${RUNTIME_VERSION} \
       --${QUOTA_TYPE}
    

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state. To check the state of your queued resource, use the following command:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  3. Connect to your TPU (worker zero) using SSH

    gcloud compute tpus tpu-vm ssh ${TPU_NAME}  \
       --project=${PROJECT_ID} \
       --zone=${ZONE}
    
  4. Set some environment variables

    export MODELS_REPO=/usr/share/tpu/models
    export PYTHONPATH="${MODELS_REPO}:${PYTHONPATH}"
    export MODEL_DIR=gcp-directory-to-store-model
    export DATA_DIR=gs://cloud-tpu-test-datasets/fake_imagenet
    export TPU_LOAD_LIBRARY=0
    export TPU_NAME=your_tpu_name
    
  5. Change to the models repository directory and install requirements.

     cd $MODELS_REPO && git checkout r2.15.0
     pip install -r official/requirements.txt
    

Train the model

Run the training script.

python3 official/vision/train.py \
   --tpu=${TPU_NAME} \
   --experiment=resnet_imagenet \
   --mode=train_and_eval \
   --model_dir=${MODEL_DIR} \
   --params_override="runtime.distribution_strategy=tpu,task.train_data.input_path=${DATA_DIR}/train*, task.validation_data.input_path=${DATA_DIR}/validation*"

Delete the TPU and queued-resource

  1. Delete your TPU

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet
    
  2. Delete your queued resource request

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
       --project=${PROJECT_ID} \
       --zone=${ZONE} \
       --quiet