This tutorial shows you how to train the ResNet-50 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimised image classification models that use PyTorch and the ImageNet dataset.
The model in this tutorial is based on the paper Deep Residual Learning for Image Recognition, which first introduced the residual network (ResNet) architecture. The tutorial uses the 50-layer variant, ResNet-50, and demonstrates training the model using PyTorch/XLA.
Objectives
- Prepare the dataset.
- Run the training job.
- Verify the output results.
Costs
In this document, you use the following billable components of Google Cloud:
- Compute Engine
- Cloud TPU
To generate a cost estimate based on your projected usage, use the pricing calculator.
Before you begin
Before starting this tutorial, check that your Google Cloud project is correctly set up.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
- In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
- Make sure that billing is enabled for your Google Cloud project.
This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you created when you've finished with them to avoid unnecessary charges.
Create a TPU VM
Open a Cloud Shell window.
Create a TPU VM:
$ gcloud compute tpus tpu-vm create your-tpu-name \
    --accelerator-type=v3-8 \
    --version=tpu-ubuntu2204-base \
    --zone=us-central1-a \
    --project=your-project
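Before connecting, it can help to confirm that the TPU VM has finished provisioning. The following is a minimal sketch, assuming the same name and zone as the create command; the TPU_NAME and ZONE variables and the tpu_state wrapper are illustrative names, not part of the tutorial.

```shell
# Illustrative variables; match them to the values passed to the create command.
TPU_NAME=your-tpu-name
ZONE=us-central1-a

# Print the lifecycle state of the TPU VM (for example CREATING or READY).
tpu_state() {
  gcloud compute tpus tpu-vm describe "$TPU_NAME" \
    --zone="$ZONE" \
    --format='value(state)'
}

# Usage, from Cloud Shell once the create command has returned:
#   tpu_state
```

A state of READY suggests the VM can accept the SSH connection in the next step.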
Connect to your TPU VM using SSH:
$ gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central1-a
Install PyTorch/XLA on your TPU VM:
(vm)$ pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html
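To confirm that the install step succeeded before going further, you can check that the three packages are importable. This is a minimal sketch rather than part of the official procedure; has_py_module is a hypothetical helper name, and it only tests that Python can locate each module, not that the TPU runtime itself works.

```shell
# Hypothetical helper: succeed if python3 can locate the named top-level module.
has_py_module() {
  python3 -c 'import importlib.util, sys
sys.exit(0 if importlib.util.find_spec(sys.argv[1]) else 1)' "$1"
}

# Report on each package installed by the pip command above.
for module in torch torch_xla torchvision; do
  if has_py_module "$module"; then
    echo "$module: found"
  else
    echo "$module: missing"
  fi
done
```

If any line reports "missing", rerunning the pip command and reading its output is usually the quickest way to find the cause.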
Clone the PyTorch/XLA GitHub repo:
(vm)$ git clone --depth=1 --branch r2.5 https://github.com/pytorch/xla.git
Run the training script with fake data:
(vm)$ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
If you are able to train the model using fake data, you can try training on real data, such as ImageNet. For instructions on downloading ImageNet, see Downloading ImageNet. In the training script command, the --datadir flag specifies the location of the dataset on which to train. The following command assumes the ImageNet dataset is located in ~/imagenet.
(vm)$ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --datadir=~/imagenet --batch_size=256 --num_epochs=1
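Before starting a long run on real data, it may be worth checking the dataset layout. The sketch below assumes the training script reads train and val subdirectories under --datadir, each holding one folder per class. That layout is an assumption not stated in this tutorial, so confirm it against the script you cloned. The check_imagenet_dir helper is a hypothetical name.

```shell
# Hypothetical helper: check that DIR has non-empty train/ and val/ subdirectories.
check_imagenet_dir() {
  dir=$1
  for split in train val; do
    if [ ! -d "$dir/$split" ]; then
      echo "missing directory: $dir/$split" >&2
      return 1
    fi
    # An empty split directory would leave the data loader with no classes.
    if [ -z "$(ls -A "$dir/$split")" ]; then
      echo "empty directory: $dir/$split" >&2
      return 1
    fi
  done
  echo "dataset layout looks OK: $dir"
}

# Check the location used in the training command above.
check_imagenet_dir ~/imagenet || echo "Fix the dataset path before training." >&2
```

The check is deliberately shallow: it does not count classes or validate image files, it only catches a wrong or empty path early.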
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
Disconnect from the TPU VM:
(vm)$ exit
Your prompt should now be username@projectname, showing you are in the Cloud Shell.
Delete your TPU VM:
$ gcloud compute tpus tpu-vm delete your-tpu-name \
    --zone=us-central1-a
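After the delete command finishes, you can confirm that no TPU VM is left running in the zone. This is a minimal sketch, assuming the same zone as the delete command; the ZONE variable and the list_tpus wrapper are illustrative names.

```shell
# Illustrative variable; match it to the zone used when creating the TPU VM.
ZONE=us-central1-a

# List the TPU VMs that still exist in the zone.
list_tpus() {
  gcloud compute tpus tpu-vm list --zone="$ZONE"
}

# Usage, from Cloud Shell:
#   list_tpus
```

The TPU VM you deleted should no longer appear in the output.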
What's next