Training ResNet-50 on Cloud TPU with PyTorch


This tutorial shows you how to train the ResNet-50 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimized image classification models that use PyTorch and the ImageNet dataset.

The model in this tutorial is based on "Deep Residual Learning for Image Recognition", the paper that introduced the residual network (ResNet) architecture. The tutorial uses the 50-layer variant, ResNet-50, and trains it with PyTorch/XLA.

Objectives

  • Prepare the dataset.
  • Run the training job.
  • Verify the output results.

Costs

In this document, you use the following billable components of Google Cloud:

  • Compute Engine
  • Cloud TPU

To generate a cost estimate based on your projected usage, use the pricing calculator. New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctly set up.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

  3. Make sure that billing is enabled for your Google Cloud project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you created when you've finished with them to avoid unnecessary charges.

Create a TPU VM

  1. Open a Cloud Shell window.

  2. Create a TPU VM:

    gcloud compute tpus tpu-vm create your-tpu-name \
      --accelerator-type=v4-8 \
      --version=tpu-ubuntu2204-base \
      --zone=us-central2-b \
      --project=your-project

  3. Connect to your TPU VM using SSH:

    gcloud compute tpus tpu-vm ssh your-tpu-name --zone=us-central2-b

  4. Install PyTorch/XLA on your TPU VM:

    (vm)$ pip install torch~=2.5.0 'torch_xla[tpu]~=2.5.0' torchvision -f https://storage.googleapis.com/libtpu-releases/index.html

  5. Clone the PyTorch/XLA GitHub repo:

    (vm)$ git clone --depth=1 --branch r2.5 https://github.com/pytorch/xla.git

  6. Run the training script with fake data:

    (vm)$ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
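The TPU name and zone passed to gcloud in steps 2 and 3 must match each other, and the same values are needed again during cleanup. One way to keep them consistent is to hold them in shell variables. The following is a minimal sketch intended for Cloud Shell, not for the TPU VM. The variable names TPU_NAME, ZONE, and PROJECT_ID and the function names create_tpu_vm and ssh_tpu_vm are illustrative assumptions, not something the tutorial requires.

```shell
# Illustrative variable names (assumptions); replace the values with your own.
TPU_NAME="your-tpu-name"
ZONE="us-central2-b"
PROJECT_ID="your-project"

# Same command as step 2, with name, zone, and project taken from the variables.
create_tpu_vm() {
  gcloud compute tpus tpu-vm create "$TPU_NAME" \
    --accelerator-type=v4-8 \
    --version=tpu-ubuntu2204-base \
    --zone="$ZONE" \
    --project="$PROJECT_ID"
}

# Same command as step 3.
ssh_tpu_vm() {
  gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="$ZONE"
}
```

After pasting the block into a Cloud Shell session, running create_tpu_vm and then ssh_tpu_vm issues the same two commands as the steps above, with a single place to change the name or zone.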

If training with fake data succeeds, you can train on real data, such as ImageNet. For instructions on downloading ImageNet, see Downloading ImageNet. In the training script command, the --datadir flag specifies the location of the dataset to train on. The following command assumes the ImageNet dataset is located in ~/imagenet. It passes the path as $HOME/imagenet because the shell does not expand ~ when it follows the = in --datadir=.

    (vm)$ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --datadir=$HOME/imagenet --batch_size=256 --num_epochs=1
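Before starting a long run on real data, it can help to confirm that the dataset directory has the shape the script is likely to expect. The sketch below assumes the script reads train and val subdirectories under --datadir, each holding one subdirectory per class. That layout is an assumption to verify against the script you cloned. The helper name check_imagenet_layout is hypothetical.

```shell
# Check that a dataset directory has train/ and val/ subfolders,
# each containing at least one subdirectory (assumed: one per class).
check_imagenet_layout() {
  root="$1"
  for split in train val; do
    if [ ! -d "$root/$split" ]; then
      echo "missing directory: $root/$split"
      return 1
    fi
    # List immediate subdirectories and keep only the first match.
    first_class="$(find "$root/$split" -mindepth 1 -maxdepth 1 -type d | head -n 1)"
    if [ -z "$first_class" ]; then
      echo "no class subdirectories in: $root/$split"
      return 1
    fi
  done
  echo "layout looks OK: $root"
}
```

On the TPU VM, check_imagenet_layout "$HOME/imagenet" prints the first problem it finds and returns a non-zero status, or reports that the layout looks OK.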
   

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Disconnect from the TPU VM:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your TPU VM:

    $ gcloud compute tpus tpu-vm delete your-tpu-name \
       --zone=us-central2-b
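To confirm the deletion, one option is to ask gcloud to describe the TPU VM and treat a failure as a sign that it no longer exists. The helper name tpu_vm_deleted is hypothetical, and the check is only a sketch: describe can also fail for other reasons, such as expired credentials or a wrong zone, so read the gcloud error message if the result is unexpected.

```shell
# Returns 0 when the TPU VM can no longer be described (assumed deleted),
# and 1 when gcloud still finds it. gcloud errors stay visible on stderr.
tpu_vm_deleted() {
  name="$1"
  zone="$2"
  if gcloud compute tpus tpu-vm describe "$name" --zone="$zone" >/dev/null; then
    echo "TPU VM '$name' still exists in $zone"
    return 1
  fi
  echo "TPU VM '$name' not found in $zone"
  return 0
}
```

For the resources in this tutorial, the call would be tpu_vm_deleted your-tpu-name us-central2-b, run from Cloud Shell after the delete command finishes.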

What's next