Tensor Processing Units (TPUs) are ML accelerators designed by Google. Cloud TPU makes TPUs available as a scalable Google Cloud resource. You can run machine learning workloads on Cloud TPUs using machine learning frameworks such as TensorFlow, PyTorch, and JAX.
The Matrix-multiply unit (MXU) is composed of 128 x 128 multiply-accumulators in a systolic array. The MXUs provide the bulk of the compute power in a TensorCore (TPU core). Each MXU can perform 16K (128 x 128) multiply-accumulate operations per cycle. All multiplies take bfloat16 inputs, but all accumulations are performed in FP32.
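To make the mixed-precision behavior concrete, here is a minimal pure-Python emulation (not TPU code, and not how the MXU is programmed). Inputs are rounded to bfloat16 by keeping the top 16 bits of their float32 encoding with round-to-nearest-even, and the running sum is kept in float32. The helper names `to_bfloat16`, `to_float32`, and `mxu_style_dot` are hypothetical, and the sketch ignores NaN, infinity, and the systolic dataflow itself.

```python
import struct

MACS_PER_CYCLE = 128 * 128  # one MAC per cell of the 128 x 128 array: 16,384 ("16K")

def to_float32(x: float) -> float:
    """Round a finite Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bfloat16(x: float) -> float:
    """Round a finite float to bfloat16: keep the top 16 bits of its float32
    encoding, rounding the discarded low 16 bits to nearest, ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1  # lowest bit that survives the truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def mxu_style_dot(a, b) -> float:
    """Dot product with bfloat16-rounded inputs and a float32 accumulator."""
    acc = 0.0
    for x, y in zip(a, b):
        # The product of two bfloat16 values fits exactly in float32
        # (barring overflow/underflow); the sum is rounded back to float32.
        acc = to_float32(acc + to_float32(to_bfloat16(x) * to_bfloat16(y)))
    return acc

print(mxu_style_dot([1.0 + 2**-10] * 4, [1.0] * 4))  # 4.0
```

Because bfloat16 keeps only 7 explicit mantissa bits, the 2^-10 term is rounded away before the multiply, so the sketch returns 4.0 where a float64 dot product would give 4.00390625. Keeping the accumulator in FP32 limits further error from summing many such products.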
The vector processing unit (VPU) is used for general computation such as activations and softmax. The scalar unit is used for control flow, memory address calculation, and other maintenance operations.
The exact layout of a TPU depends on the TPU version that you use. Architectural details and performance characteristics of TPU v2 and v3 are available in A Domain Specific Supercomputer for Training Deep Neural Networks.
Each v4 TPU chip contains two TensorCores. Each TensorCore has four MXUs, a vector unit, and a scalar unit. The following table shows the key specifications and their values for a v4 TPU Pod.
| Key specification | v4 Pod value |
|---|---|
| Peak compute per chip | 275 teraflops (bf16 or int8) |
| HBM2 capacity and bandwidth per chip | 32 GiB, 1200 GB/s |
| Measured min/mean/max power per chip | 90/170/192 W |
| TPU Pod size | 4096 chips |
| Interconnect topology | 3D torus |
| Peak compute per Pod | 1.1 exaflops (bf16 or int8) |
| All-reduce bandwidth per Pod | 1.1 PB/s |
| Bisection bandwidth per Pod | 24 TB/s |
The following diagram illustrates a TPU v4 chip.
Each v3 TPU chip contains two TensorCores. Each TensorCore has two MXUs, a vector unit, and a scalar unit. The following table shows the key specifications and their values for a v3 TPU Pod.
| Key specification | v3 Pod value |
|---|---|
| Peak compute per chip | 123 teraflops (bf16) |
| HBM2 capacity and bandwidth per chip | 32 GiB, 900 GB/s |
| Measured min/mean/max power per chip | 123/220/262 W |
| TPU Pod size | 1024 chips |
| Interconnect topology | 2D torus |
| Peak compute per Pod | 126 petaflops (bf16) |
| All-reduce bandwidth per Pod | 340 TB/s |
| Bisection bandwidth per Pod | 6.4 TB/s |
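As a quick consistency check on the v4 and v3 tables, the sketch below multiplies peak compute per chip by the number of chips per Pod. It assumes that Pod peak is simply that product; the dictionary layout and the function name `pod_peak_pflops` are illustrative only.

```python
# Per-chip peak compute and Pod size, copied from the v4 and v3 tables.
SPECS = {
    "v4": {"peak_tflops_per_chip": 275, "chips_per_pod": 4096},
    "v3": {"peak_tflops_per_chip": 123, "chips_per_pod": 1024},
}

def pod_peak_pflops(version: str) -> float:
    """Pod peak in petaflops, assuming it is per-chip peak x chip count."""
    spec = SPECS[version]
    return spec["peak_tflops_per_chip"] * spec["chips_per_pod"] / 1000  # TFLOPS -> PFLOPS

for version in SPECS:
    print(version, f"{pod_peak_pflops(version):,.1f} petaflops")
# v4 1,126.4 petaflops
# v3 126.0 petaflops
```

Both results round to the Pod figures in the tables: about 1.1 exaflops for v4 and 126 petaflops for v3.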
The following diagram illustrates a TPU v3 chip.
The smallest TPU v2 configuration contains four TPU chips and 16 GiB of HBM. Each TPU chip contains two TensorCores. Each TensorCore has an MXU, a vector unit, and a scalar unit. The following diagram illustrates a TPU v2 chip.
Cloud TPU provides the following TPU configurations:
- A single TPU device
- A TPU Pod - a group of TPU devices connected by high-speed interconnects
- A TPU slice - a subdivision of a TPU Pod
Performance benefits of TPU v4 over v3
- v4 TPU chips have a unified 32-GiB HBM memory space across the entire chip, enabling better coordination between the two on-chip TensorCores.
- Improved HBM performance using the latest memory standards and speeds.
- Improved DMA performance profile with built-in support for high-performance striding at 512-byte granularity.
- Twice the number of MXUs and a higher clock rate, delivering a peak of 275 TFLOPS per chip.
- Twice the transposition and permutation bandwidth.
- Load-store memory access model for Common Memory (Cmem).
- Higher MXU weight-loading bandwidth and 8-bit mode support, allowing smaller batch sizes and lower inference latency.
- Six interconnect links per chip, enabling network topologies with smaller network diameters.
- x16 PCIe Gen3 interface to the host (direct connect).
- Improved security model.
- Improved energy efficiency.
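Several figures on this page can be combined to see how the v4 peak is reached. The sketch below backs out the clock rate implied by the quoted per-chip peaks, assuming that each multiply-accumulate counts as two floating-point operations and that "peak" means every MXU completes a full 128 x 128 set of MACs every cycle. The resulting clock rates are derived estimates, not published specifications, and `implied_clock_hz` is a hypothetical helper.

```python
MACS_PER_MXU_PER_CYCLE = 128 * 128  # 128 x 128 systolic array -> 16,384 MACs per cycle
FLOPS_PER_MAC = 2                   # assumption: one multiply plus one add

def implied_clock_hz(peak_flops_per_chip: float, mxus_per_chip: int) -> float:
    """Clock rate at which all MXUs, busy every cycle, would hit the quoted peak."""
    return peak_flops_per_chip / (mxus_per_chip * MACS_PER_MXU_PER_CYCLE * FLOPS_PER_MAC)

# MXUs per chip = TensorCores per chip x MXUs per TensorCore
v4 = implied_clock_hz(275e12, 2 * 4)
v3 = implied_clock_hz(123e12, 2 * 2)
print(f"v4 ~{v4 / 1e9:.2f} GHz, v3 ~{v3 / 1e9:.2f} GHz, ratio {v4 / v3:.2f}")
# v4 ~1.05 GHz, v3 ~0.94 GHz, ratio 1.12
```

Under these assumptions, doubling the MXUs per chip accounts for 2x and the higher clock for roughly 1.12x, which together give the 275/123 (about 2.24x) increase in per-chip peak compute.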
Performance benefits of TPU v3 over v2
The increased FLOPS per core and memory capacity in TPU v3 configurations can improve the performance of your models in the following ways:
- TPU v3 configurations provide significant performance benefits per core for compute-bound models. Memory-bound models on TPU v2 configurations might not achieve this same performance improvement if they are also memory-bound on TPU v3 configurations.
- In cases where data does not fit into memory on TPU v2 configurations, TPU v3 can provide improved performance and reduced recomputation of intermediate values (rematerialization).
- TPU v3 configurations can run new models with batch sizes that did not fit on TPU v2 configurations. For example, TPU v3 might allow deeper ResNets and larger images with RetinaNet.
- Models that are nearly input-bound ("infeed") on TPU v2 because training steps are waiting for input might also be input-bound with Cloud TPU v3. The pipeline performance guide can help you resolve infeed issues.
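Whether a model is compute-bound or memory-bound can be estimated with a simple roofline calculation: compare a kernel's arithmetic intensity (FLOPs per byte of HBM traffic) with the chip's ratio of peak compute to HBM bandwidth. TPU v2 figures are not listed on this page, so the sketch below uses the v3 and v4 table values instead. It assumes "GBps" means 10^9 bytes per second and uses an idealized matmul that reads each input and writes the output exactly once in bfloat16; real kernels differ, and all function names are hypothetical.

```python
# Per-chip peak compute (FLOP/s) and HBM bandwidth (bytes/s) from the tables.
# Assumption: "GBps" is read as 10**9 bytes per second.
CHIPS = {
    "v3": {"peak_flops": 123e12, "hbm_bytes_per_s": 900e9},
    "v4": {"peak_flops": 275e12, "hbm_bytes_per_s": 1200e9},
}

def ridge_point(chip: str) -> float:
    """Arithmetic intensity (FLOPs/byte) at which compute time equals HBM time."""
    c = CHIPS[chip]
    return c["peak_flops"] / c["hbm_bytes_per_s"]

def square_matmul_bf16(n: int) -> tuple:
    """Idealized n x n x n matmul: 2*n**3 FLOPs; A and B read once, C written once,
    2 bytes per bfloat16 element (ignores tiling, caching, and fp32 outputs)."""
    return 2 * n**3, 3 * n * n * 2

def classify(flops: float, hbm_bytes: float, chip: str) -> str:
    intensity = flops / hbm_bytes
    return "compute-bound" if intensity >= ridge_point(chip) else "memory-bound"

for chip in CHIPS:
    print(chip, round(ridge_point(chip)), classify(*square_matmul_bf16(512), chip))
# v3 137 compute-bound
# v4 229 memory-bound
```

A 512 x 512 matmul has an intensity of about 171 FLOPs/byte in this model, above the v3 ridge point but below the v4 one. It therefore gains only about 1.67x moving from v3 to v4 rather than the 2.24x increase in peak compute, the same effect the first bullet describes for memory-bound models.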
TPU v4 configurations
A TPU v4 Pod is composed of 4096 chips interconnected with reconfigurable high-speed links. TPU v4's flexible networking lets you connect the chips in a slice of a given size in multiple ways. You configure v4 TPUs using AcceleratorConfig, which lets you specify the slice size in terms of chips through its topology.
TPU v4 topology is specified using a 3-tuple that describes the number, shape, and interconnections of TPU chips. The following illustrations show some common TPU v4 topologies.
Larger slices can be built from one or more 4x4x4 "cubes" of chips.
TPU slices of a given number of chips can be configured in different ways. For example, a TPU slice with the AcceleratorType v4-1024 (512 chips) can be configured as 4x4x32, 4x8x16, or 8x8x8.
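The three v4-1024 layouts can be reproduced by enumerating chip counts along each axis. This sketch rests on stated assumptions rather than any official API: the number after "v4-" counts TensorCores, each chip has two TensorCores, and every dimension is a multiple of 4 so that the slice is built from whole 4x4x4 cubes (so it covers only slices of 64 chips or more). The function `v4_topologies` is hypothetical.

```python
def v4_topologies(accelerator_type: str, cube: int = 4):
    """List sorted (x, y, z) chip topologies for a v4 AcceleratorType such as "v4-1024".

    Assumptions (illustrative, not an official API): the suffix counts TensorCores,
    there are two TensorCores per chip, and each dimension is a multiple of the
    cube edge, so the slice is made of whole 4x4x4 cubes.
    """
    cores = int(accelerator_type.split("-")[1])
    chips = cores // 2
    topologies = []
    for x in range(cube, chips + 1, cube):
        for y in range(x, chips + 1, cube):       # y >= x avoids duplicate orderings
            if chips % (x * y):
                continue                          # x * y must divide the chip count
            z = chips // (x * y)
            if z >= y and z % cube == 0:          # keep x <= y <= z and whole cubes
                topologies.append((x, y, z))
    return topologies

print(["x".join(map(str, t)) for t in v4_topologies("v4-1024")])
# ['4x4x32', '4x8x16', '8x8x8']
```

The enumeration returns exactly the three shapes listed above, which is consistent with (though not proof of) the cube-multiple assumption.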
TPU v2 and v3 configurations
TPUs are available in the following configurations:
- A single TPU board
- A TPU Pod
- A TPU Pod slice
Single TPU Board
A single-board TPU configuration is a standalone board with four TPU chips (eight TensorCores) with no network connections to other TPU boards. Single-board TPUs are not part of a TPU Pod configuration and do not occupy a portion of a TPU Pod.
TPU Pods and Slices
In a TPU Pod or TPU Pod slice, TPU chips are connected using a high-speed interconnect. Each TPU chip communicates directly with the other chips on the TPU device. The TPU software automatically handles distributing data to each TPU core in a Pod or slice. Pod slices are available with 32, 128, 512, 1024, or 2048 cores.
Cloud TPU VM Architectures
TPUs are designed to perform matrix operations quickly. Each TPU board is connected to a CPU-based host machine to perform operations that cannot be executed on the TPU. The host machines are responsible for loading data from Cloud Storage, preprocessing data, and sending data to the TPU.
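Host-side feeding of the TPU is commonly organized as a producer/consumer pipeline, so that preprocessing of the next batch overlaps with the current training step (the same idea as `tf.data.Dataset.prefetch`). The sketch below shows only that overlap, using plain Python threads; `preprocess`, `device_step`, and `run_pipeline` are hypothetical stand-ins, not Cloud TPU or framework APIs.

```python
import queue
import threading

def preprocess(raw):
    # Hypothetical stand-in for host-side work (decoding, augmentation, batching).
    return [x * 2 for x in raw]

def device_step(batch):
    # Hypothetical stand-in for one accelerator step; here it just sums the batch.
    return sum(batch)

def run_pipeline(raw_batches, prefetch=2):
    """Overlap host preprocessing with device steps through a bounded queue."""
    q = queue.Queue(maxsize=prefetch)
    done = object()  # sentinel that marks the end of the input

    def producer():
        for raw in raw_batches:
            q.put(preprocess(raw))  # blocks once `prefetch` batches are waiting
        q.put(done)

    threading.Thread(target=producer, daemon=True).start()
    results = []
    while (batch := q.get()) is not done:
        results.append(device_step(batch))  # producer prepares the next batch meanwhile
    return results

print(run_pipeline([[1, 2], [3, 4], [5]]))  # [6, 14, 10]
```

The bounded queue keeps the host at most `prefetch` batches ahead, which caps memory use. A production pipeline would also need to propagate errors raised in the producer thread, which this sketch omits.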
In a TPU Pod, there is a TPU host for each TPU board.
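Combined with the single-board size (four chips, eight TensorCores), one host per board means the number of hosts behind a v2 or v3 Pod slice follows directly from its core count. This sketch assumes every board in the slice is a full eight-core board; `hosts_for_slice` is a hypothetical name.

```python
CORES_PER_BOARD = 8  # four chips x two TensorCores per chip on a v2/v3 board

def hosts_for_slice(cores: int) -> int:
    """Number of TPU hosts for a v2/v3 slice, assuming one host per full board."""
    if cores % CORES_PER_BOARD:
        raise ValueError("slice must consist of whole boards")
    return cores // CORES_PER_BOARD

for cores in (32, 128, 512, 1024, 2048):
    print(f"{cores} cores -> {hosts_for_slice(cores)} hosts")
# 32 -> 4, 128 -> 16, 512 -> 64, 1024 -> 128, 2048 -> 256
```

Because PyTorch and JAX on TPU Nodes need a user VM for each TPU host, the same count also gives the number of user VMs such a slice requires.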
How you interact with the TPU host (and the TPU board) depends upon the TPU VM architecture you're using: TPU Nodes or TPU VMs.
The TPU Node architecture consists of a user VM that communicates with the TPU host over gRPC. With this architecture, you cannot directly access the TPU host, which makes training and TPU errors difficult to debug.
The TPU VM architecture enables you to SSH into the VM physically connected to the TPU device. You have root access to the VM, so you can run arbitrary code. You can access compiler and runtime debug logs and error messages.
Frameworks like JAX, PyTorch, and TensorFlow access TPUs through a shared library, libtpu, that is present on every TPU VM. libtpu includes the XLA compiler, TPU runtime software, and the TPU driver.
With TPU VMs, your Python code can run directly on the TPU host instead of on a separate user VM.
For more information on TensorFlow and Cloud TPU VM, see the Cloud TPU VM user's guide.
The Cloud TPU Node system architecture was originally built for TensorFlow. The TPU hosts are inaccessible to the user and run a headless copy of TensorFlow server. They don't run Python or any user code not represented as a TensorFlow graph. User code runs in a separate user VM that communicates with the TPU hosts over gRPC.
With TPU VMs, your PyTorch code runs directly on the TPU hosts.
For more information on PyTorch and Cloud TPU, see the PyTorch/XLA user guide.
PyTorch runs on the Cloud TPU Node architecture using a library called XRT. XRT sends XLA graphs and runtime instructions over gRPC to be run on TensorFlow servers. A user VM is required for each TPU host.
With TPU VMs, your JAX code runs directly on the TPU hosts.
For more information on running JAX on Cloud TPU, see JAX quickstart.
JAX on Cloud TPU Nodes works similarly to PyTorch: a separate user VM is required for each host VM.