Frequently Asked Questions - Cloud TPU

This document answers frequently asked questions about Cloud TPUs. It is divided into the following sections:

  1. Framework independent FAQs - questions about using Cloud TPUs regardless of which ML framework you use.
  2. JAX FAQs - questions about using Cloud TPUs with JAX.
  3. PyTorch FAQs - questions about using Cloud TPUs with PyTorch.

If you are looking for information about using Cloud TPUs with TensorFlow, see Run TensorFlow models on Cloud TPU.

Framework independent FAQs

How do I check which process is using the TPU on a Cloud TPU VM?

Run sudo lsof -w /dev/accel* on the Cloud TPU VM to print the process ID and other information about the process using the TPU.
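
If you want to run the same check from a script, the following is a minimal sketch rather than an official tool. The helper names parse_lsof and tpu_device_users are hypothetical, and the parser assumes the default lsof column layout (COMMAND first, PID second, file name last). Because no shell is involved, the /dev/accel* wildcard is expanded with glob instead.

```python
import glob
import subprocess


def parse_lsof(output):
    """Return (command, pid, device_path) tuples from default lsof output."""
    rows = []
    for line in output.splitlines()[1:]:  # skip the COMMAND/PID/... header row
        fields = line.split()
        # Assumption: COMMAND is the first column, PID the second, NAME the last.
        if len(fields) >= 2 and fields[1].isdigit():
            rows.append((fields[0], int(fields[1]), fields[-1]))
    return rows


def tpu_device_users():
    """List processes that hold a TPU device file open on this VM."""
    # Expand the wildcard here, because subprocess does not go through a shell.
    devices = sorted(glob.glob("/dev/accel*"))
    if not devices:
        return []  # no TPU device files are visible from this machine
    result = subprocess.run(
        ["sudo", "lsof", "-w", *devices],
        capture_output=True,
        text=True,
        check=False,  # lsof exits non-zero when no process has the files open
    )
    return parse_lsof(result.stdout)
```

Calling tpu_device_users() on a Cloud TPU VM returns one tuple per open device file, so a process that holds several TPU chips appears several times with the same process ID.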

How do I add a persistent disk volume to a Cloud TPU VM?

See Add a persistent disk to a TPU VM.

What storage options are supported or recommended for training on a TPU VM?

For more information, see Cloud TPU storage options.

JAX FAQs

How do I know if the TPU is being used by my program?

There are a few ways to verify that JAX is using the TPU:

  1. Use the jax.devices() function. For example:

    import jax
    assert jax.devices()[0].platform == 'tpu'
    
  2. Profile your program and verify that the profile contains TPU operations. For more information, see Profiling JAX programs.

For more information, see the JAX FAQ.
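
An assertion stops at the first device. As a slightly more informative variant, the sketch below counts the devices JAX reports per platform. The helper names count_platforms and report_jax_devices are hypothetical; the only JAX features used are jax.devices() and the platform attribute of each device.

```python
from collections import Counter


def count_platforms(devices):
    """Count devices per platform string, for example 'tpu' or 'cpu'."""
    return Counter(d.platform for d in devices)


def report_jax_devices():
    """Print the platforms JAX sees and return True if any device is a TPU."""
    import jax  # imported lazily so count_platforms works without JAX installed

    counts = count_platforms(jax.devices())
    for platform, n in sorted(counts.items()):
        print(f"{platform}: {n} device(s)")
    return counts.get("tpu", 0) > 0
```

If report_jax_devices() returns False on a TPU VM, JAX has most likely fallen back to the CPU backend, which is the same condition the assertion above would catch.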

PyTorch FAQs

How do I know if the TPU is being used by my program?

Run the following Python commands:

>>> import torch_xla.core.xla_model as xm
>>> xm.get_xla_supported_devices(devkind="TPU")

Then verify that the output lists at least one TPU device.
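
In a script, where nobody reads the interpreter output, the same check can fail loudly instead. The following is a sketch under stated assumptions: require_tpu and check_torch_xla_tpu are hypothetical helper names, and the code assumes that xm.get_xla_supported_devices returns a list of device strings when TPUs are present and either None or an empty list otherwise, which may vary between torch_xla releases.

```python
def require_tpu(devices):
    """Return the device list, or raise if it is None or empty."""
    if not devices:
        raise RuntimeError("No TPU devices visible to PyTorch/XLA")
    return list(devices)


def check_torch_xla_tpu():
    """Query PyTorch/XLA for TPU devices and raise if none are found."""
    import torch_xla.core.xla_model as xm  # lazy import; requires torch_xla

    # Assumption: the result is a list of device strings, or None/empty.
    devices = require_tpu(xm.get_xla_supported_devices(devkind="TPU"))
    print(f"Found {len(devices)} TPU device(s): {devices}")
    return devices
```

Calling check_torch_xla_tpu() at the start of a training script surfaces a missing TPU immediately, rather than letting the job run on an unintended device.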