This tutorial shows you how to train the ResNet-50 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimized image classification models that use PyTorch and the ImageNet dataset.
The model in this tutorial is based on Deep Residual Learning for Image Recognition, which first introduced the residual network (ResNet) architecture. The tutorial uses the 50-layer variant, ResNet-50, and demonstrates training the model using PyTorch/XLA.
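As context, the residual connection that gives ResNet its name adds a block's input to its transformed output before the final nonlinearity. The following is a minimal, framework-free sketch of that idea; the `transform` argument is a hypothetical stand-in for a block's convolutional layers, not code from the training script:

```python
def relu(vec):
    """Element-wise ReLU nonlinearity."""
    return [max(0.0, v) for v in vec]

def residual_block(x, transform):
    """Compute relu(F(x) + x). The identity ("skip") path lets
    gradients flow through even when F(x) is hard to optimize,
    which is what makes very deep networks like ResNet-50 trainable."""
    fx = transform(x)
    return relu([a + b for a, b in zip(fx, x)])

# Toy example: the transform simply halves its input.
out = residual_block([1.0, -2.0, 3.0], lambda v: [0.5 * w for w in v])
print(out)  # [1.5, 0.0, 4.5]
```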
Objectives
- Prepare the dataset.
- Run the training job.
- Verify the output results.
Costs
In this document, you use the following billable components of Google Cloud:
- Compute Engine
- Cloud TPU
To generate a cost estimate based on your projected usage,
use the pricing calculator.
Before you begin
Before starting this tutorial, check that your Google Cloud project is correctly set up.
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
- In the Google Cloud console, on the project selector page, select or create a Google Cloud project.
- Make sure that billing is enabled for your Google Cloud project. Learn how to check if billing is enabled on a project.
This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you create when you've finished with them to avoid unnecessary charges.
Create a TPU VM
Open a Cloud Shell window.
Create a TPU VM:
gcloud compute tpus tpu-vm create your-tpu-name \
  --accelerator-type=v4-8 \
  --version=tpu-vm-v4-pt-2.0 \
  --zone=us-central2-b \
  --project=your-project
Connect to your TPU VM using SSH:
gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central2-b
Clone the PyTorch/XLA GitHub repository:
(vm)$ git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
Run the training script with fake data
There are two PyTorch/XLA runtime options: PJRT and XRT. We recommend you use PJRT unless you know you need to use XRT. To learn more about the different runtime configurations, see the PJRT runtime documentation.
PJRT
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
XRT (legacy)
Configure XRT
(vm) $ export TPU_NUM_DEVICES=4
(vm) $ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
Run the training script
(vm) $ python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
If you are able to train the model using fake data, you can try training on real data, such as ImageNet. For instructions on downloading ImageNet, see Downloading ImageNet. In the training script command, the --datadir flag specifies the location of the dataset to train on. The following commands assume the ImageNet dataset is located in ~/imagenet.
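The training script loads the dataset with a torchvision-style ImageFolder layout, so the directory passed to --datadir is expected to contain train and val subdirectories with one folder per class. A sketch of the expected structure (the class and file names below are illustrative, not required names):

```
~/imagenet/
├── train/
│   ├── n01440764/
│   │   ├── n01440764_10026.JPEG
│   │   └── ...
│   └── ...
└── val/
    ├── n01440764/
    │   └── ...
    └── ...
```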
PJRT
(vm) $ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --datadir=~/imagenet --batch_size=256 --num_epochs=1
XRT (legacy)
Configure XRT
(vm) $ export TPU_NUM_DEVICES=4
(vm) $ export XRT_TPU_CONFIG="localservice;0;localhost:51011"
Run the training script
(vm) $ python3 xla/test/test_train_mp_imagenet.py --datadir=~/imagenet --batch_size=256 --num_epochs=1
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
Disconnect from the TPU VM:
(vm) $ exit
Your prompt should now be username@projectname, showing that you are in Cloud Shell.
Delete your TPU VM:
$ gcloud compute tpus tpu-vm delete your-tpu-name \
  --zone=us-central2-b
What's next
Try the PyTorch colabs:
- Getting Started with PyTorch on Cloud TPUs
- Training MNIST on TPUs
- Training ResNet18 on TPUs with Cifar10 dataset
- Inference with Pretrained ResNet50 Model
- Fast Neural Style Transfer
- MultiCore Training AlexNet on Fashion MNIST
- Single Core Training AlexNet on Fashion MNIST