Frequently Asked Questions - Cloud TPU

This document contains answers to frequently asked questions about Cloud TPUs. It is divided into the following sections:

  1. Framework-independent FAQs - questions about using Cloud TPUs regardless of which ML framework you use.
  2. JAX FAQs - questions about using Cloud TPUs with JAX.
  3. PyTorch FAQs - questions about using Cloud TPUs with PyTorch.

Framework-independent FAQs

How do I check which process is using the TPU on a Cloud TPU VM?

Run the following command on the Cloud TPU VM to print the process ID and other information about each process using the TPU:

    sudo lsof -w /dev/accel*

How do I add a persistent disk volume to a Cloud TPU VM?

For more information, see Add a persistent disk to a TPU VM.

What storage options are supported or recommended for training with a TPU VM?

For more information, see Cloud TPU storage options.

JAX FAQs

How do I know if the TPU is being used by my program?

There are a few ways to verify that JAX is using the TPU:

  1. Use the jax.devices() function (see the runnable sketch after this list). For example:

    assert jax.devices()[0].platform == 'tpu'
    
  2. Profile your program and verify that the profile contains TPU operations. For more information, see Profiling JAX programs.
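
Combining the checks above, here is a minimal, runnable sketch (assuming JAX is installed with TPU support on the Cloud TPU VM):

    import jax

    # List the devices JAX can see; on a TPU VM this should include TPU devices.
    print(jax.devices())

    # Fail fast if the default backend is not a TPU.
    assert jax.devices()[0].platform == 'tpu'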

For more information, see the JAX FAQ.

PyTorch FAQs

How do I know if the TPU is being used by my program?

You can run the following Python commands:

>>> import torch_xla.core.xla_model as xm
>>> xm.get_xla_supported_devices(devkind="TPU")

Verify that the output lists one or more TPU devices.
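
As an additional check, the following is a minimal sketch (assuming torch and torch_xla are installed on the TPU VM) that allocates a tensor on the XLA device and prints its device string:

    import torch
    import torch_xla.core.xla_model as xm

    # Acquire the default XLA device; on a Cloud TPU VM this maps to a TPU core.
    dev = xm.xla_device()

    # Allocate a small tensor on the device; printing its device string
    # should show an XLA device such as xla:0.
    t = torch.ones(2, 2, device=dev)
    print(t.device)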