This page describes how to accelerate machine learning (ML) workloads by using Cloud TPU accelerators (TPUs) in Google Kubernetes Engine (GKE) Autopilot clusters. This guidance can help you to select the correct libraries for your ML application frameworks, set up your TPU workloads to run optimally on GKE, and monitor your workloads after deployment.
This page is for Platform admins and operators, Data and AI specialists, and Application developers who want to prepare and run ML workloads on TPUs. To learn more about the common roles, responsibilities, and example tasks that we reference in Google Cloud content, see Common GKE Enterprise user roles and tasks.
Before reading this page, ensure that you're familiar with the following resources:
How TPUs work in Autopilot
To use TPUs in Autopilot workloads, you specify the following in your workload manifest:
- The TPU version in the spec.nodeSelector field.
- The TPU topology in the spec.nodeSelector field. The topology must be supported by the specified TPU version.
- The number of TPU chips in the spec.containers.resources.requests and the spec.containers.resources.limits fields.
When you deploy the workload, GKE provisions nodes that have the requested TPU configuration and schedules your Pods on the nodes. GKE places each workload on its own node so that each Pod can access the full resources of the node with minimized risk of disruption.
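As a rough illustration of the three manifest fields described above, the following Python sketch assembles them into a plain dictionary. The helper name `tpu_pod_fragment` is hypothetical and is not part of any GKE tooling; the label and resource keys are the ones used in the manifests on this page.

```python
# Sketch only: builds the TPU-related parts of a Pod spec as a dict.
# `tpu_pod_fragment` is a hypothetical helper, not a GKE API.

def tpu_pod_fragment(tpu_type: str, topology: str, chips: int) -> dict:
    """Return the nodeSelector and container resources for a TPU request."""
    if chips <= 0:
        raise ValueError("chips must be a positive integer")
    tpu_resources = {"google.com/tpu": chips}
    return {
        "nodeSelector": {
            # TPU version and topology both go in spec.nodeSelector.
            "cloud.google.com/gke-tpu-accelerator": tpu_type,
            "cloud.google.com/gke-tpu-topology": topology,
        },
        "containers": [
            {
                "name": "tpu-job",
                # The chip count appears in both requests and limits.
                "resources": {
                    "requests": dict(tpu_resources),
                    "limits": dict(tpu_resources),
                },
            }
        ],
    }


fragment = tpu_pod_fragment("tpu-v5-lite-podslice", "2x4", 8)
```

The function only shapes the data. It does not check whether GKE supports the chosen version and topology combination.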
TPUs in Autopilot are compatible with the following capabilities:
Plan your TPU configuration
Before you use this guide to deploy TPU workloads, plan your TPU configuration based on your model and how much memory it requires. For details, see Plan your TPU configuration.
Pricing
For pricing information, see Autopilot pricing.
Before you begin
Before you start, make sure you have performed the following tasks:
- Enable the Google Kubernetes Engine API.
- If you want to use the Google Cloud CLI for this task, install and then initialize the gcloud CLI. If you previously installed the gcloud CLI, get the latest version by running gcloud components update.
- Ensure that you have an Autopilot cluster running GKE version 1.29.2-gke.1521000 or later.
- To use reserved TPUs, ensure that you have an existing specific capacity reservation. For instructions, see Consuming reserved zonal resources.
Ensure that you have TPU quota
The following sections help you ensure that you have enough quota when using TPUs in GKE. To create TPU slice nodes, you must have TPU quota available unless you're using an existing capacity reservation. If you're using reserved TPUs, skip this section.
Creating TPU slice nodes in GKE requires Compute Engine API quota (compute.googleapis.com), not Cloud TPU API quota (tpu.googleapis.com). The name of the quota is different in regular Autopilot Pods and in Spot Pods.
To check the limit and current usage of your Compute Engine API quota for TPUs, follow these steps:
Go to the Quotas page in the Google Cloud console:
In the Filter box, do the following:
- Select the Service property, enter Compute Engine API, and press Enter.
- Select the Type property and choose Quota.
- Select the Name property and enter the name of the quota based on the TPU version and value in the cloud.google.com/gke-tpu-accelerator node selector. For example, if you plan to create on-demand TPU v5e nodes whose value in the cloud.google.com/gke-tpu-accelerator node selector is tpu-v5-lite-podslice, enter TPU v5 Lite PodSlice chips.

| TPU version | cloud.google.com/gke-tpu-accelerator | Name of the quota for on-demand instances | Name of the quota for Spot instances |
| --- | --- | --- | --- |
| TPU v3 | tpu-v3-device | TPU v3 Device chips | Preemptible TPU v3 Device chips |
| TPU v3 | tpu-v3-slice | TPU v3 PodSlice chips | Preemptible TPU v3 PodSlice chips |
| TPU v4 | tpu-v4-podslice | TPU v4 PodSlice chips | Preemptible TPU v4 PodSlice chips |
| TPU v5e | tpu-v5-lite-device | TPU v5 Lite Device chips | Preemptible TPU v5 Lite Device chips |
| TPU v5e | tpu-v5-lite-podslice | TPU v5 Lite PodSlice chips | Preemptible TPU v5 Lite PodSlice chips |
| TPU v5p | tpu-v5p-slice | TPU v5p chips | Preemptible TPU v5p chips |
| TPU Trillium | tpu-v6e-slice | TPU v6e Slice chips | Preemptible TPU v6e Lite PodSlice chips |

- Select the Dimensions (e.g. locations) property and enter region: followed by the name of the region in which you plan to create TPUs in GKE. For example, enter region:us-west4 if you plan to create TPU slice nodes in the zone us-west4-a. TPU quota is regional, so all zones within the same region consume the same TPU quota.
If no quotas match the filter you entered, then the project has not been granted any of the specified quota for the region that you need, and you must request a TPU quota increase.
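To make the filter steps above easier to script, here is a minimal Python sketch that looks up the quota name for a node selector value and derives the region dimension from a zone. The quota names are copied from the table on this page. The function names are hypothetical, and the zone-to-region rule (dropping the final `-a`, `-b`, and so on) is an assumption based on the `us-west4-a` example.

```python
# Quota names copied from the table on this page:
# accelerator label value -> (on-demand quota name, Spot quota name).
QUOTA_NAMES = {
    "tpu-v3-device": ("TPU v3 Device chips", "Preemptible TPU v3 Device chips"),
    "tpu-v3-slice": ("TPU v3 PodSlice chips", "Preemptible TPU v3 PodSlice chips"),
    "tpu-v4-podslice": ("TPU v4 PodSlice chips", "Preemptible TPU v4 PodSlice chips"),
    "tpu-v5-lite-device": ("TPU v5 Lite Device chips", "Preemptible TPU v5 Lite Device chips"),
    "tpu-v5-lite-podslice": ("TPU v5 Lite PodSlice chips", "Preemptible TPU v5 Lite PodSlice chips"),
    "tpu-v5p-slice": ("TPU v5p chips", "Preemptible TPU v5p chips"),
    "tpu-v6e-slice": ("TPU v6e Slice chips", "Preemptible TPU v6e Lite PodSlice chips"),
}


def quota_name(accelerator: str, spot: bool = False) -> str:
    """Return the quota name to type into the Name filter."""
    on_demand, preemptible = QUOTA_NAMES[accelerator]
    return preemptible if spot else on_demand


def region_dimension(zone: str) -> str:
    """Turn a zone such as us-west4-a into the filter value region:us-west4."""
    region = zone.rsplit("-", 1)[0]  # assumption: zone = region + "-" + suffix
    return f"region:{region}"
```

For example, `quota_name("tpu-v5-lite-podslice")` returns the on-demand quota name from the table, and `region_dimension("us-west4-a")` returns the dimension value used in the last filter step.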
When a TPU reservation is created, both the limit and current use values for the corresponding quota increase by the number of chips in the TPU reservation. For example, when a reservation is created for 16 TPU v5e chips whose value in the cloud.google.com/gke-tpu-accelerator node selector is tpu-v5-lite-podslice, then both the Limit and Current usage for the TPU v5 Lite PodSlice chips quota in the relevant region increase by 16.
Quotas for additional GKE resources
You may need to increase the following GKE-related quotas in the regions where GKE creates your resources.
- Persistent Disk SSD (GB) quota: The boot disk of each Kubernetes node requires 100GB by default. Therefore, this quota should be set at least as high as the product of the maximum number of GKE nodes you anticipate creating and 100GB (nodes * 100GB).
- In-use IP addresses quota: Each Kubernetes node consumes one IP address. Therefore, this quota should be set at least as high as the maximum number of GKE nodes you anticipate creating.
- Ensure that max-pods-per-node aligns with the subnet range: Each Kubernetes node uses secondary IP ranges for Pods. For example, max-pods-per-node of 32 requires 64 IP addresses, which translates to a /26 subnet per node. Note that this range shouldn't be shared with any other cluster. To avoid exhausting the IP address range, use the --max-pods-per-node flag to limit the number of Pods allowed to be scheduled on a node. The quota for max-pods-per-node should be set at least as high as the maximum number of GKE nodes you anticipate creating.
To request an increase in quota, see Request higher quota.
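The sizing rules in the list above reduce to simple arithmetic. The following Python sketch is one way to estimate them. The function names are hypothetical. The per-node Pod range rule (twice the maximum Pods per node, rounded up to a power of two) is an assumption generalized from the single example above, where 32 Pods need 64 addresses and a /26 range.

```python
BOOT_DISK_GB = 100  # default boot disk size per node, as stated above


def ssd_quota_gb(max_nodes: int) -> int:
    """Persistent Disk SSD quota: nodes * 100 GB."""
    return max_nodes * BOOT_DISK_GB


def in_use_ip_quota(max_nodes: int) -> int:
    """In-use IP addresses quota: one address per node."""
    return max_nodes


def pod_range_prefix(max_pods_per_node: int) -> int:
    """Prefix length of the per-node Pod range (assumed rule, see lead-in)."""
    addresses = 2 * max_pods_per_node
    # Number of host bits needed to hold `addresses`, rounded up.
    host_bits = (addresses - 1).bit_length()
    return 32 - host_bits
```

With these definitions, `pod_range_prefix(32)` gives 26, which matches the /26 example in the list.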
Options for provisioning TPUs in GKE
GKE Autopilot lets you use TPUs directly in individual workloads by using Kubernetes nodeSelectors.
Alternatively, you can request TPUs by using custom compute classes. Custom compute classes let platform administrators define a hierarchy of node configurations for GKE to prioritize during node scaling decisions, so that workloads run on your selected hardware.
For instructions, see the Centrally provision TPUs with custom compute classes section.
Prepare your TPU application
TPU workloads have the following preparation requirements.
- Frameworks like JAX, PyTorch, and TensorFlow access TPU VMs using the libtpu shared library. libtpu includes the XLA compiler, TPU runtime software, and the TPU driver. Each release of PyTorch and JAX requires a certain libtpu.so version. To use TPUs in GKE, ensure that you use the following versions:

| TPU type | libtpu.so version |
| --- | --- |
| TPU Trillium (v6e): tpu-v6e-slice | Recommended jax[tpu] version: v0.4.9 or later. Recommended torch_xla[tpuvm] version: v2.1.0 or later. |
| TPU v5e: tpu-v5-lite-podslice, tpu-v5-lite-device | Recommended jax[tpu] version: v0.4.9 or later. Recommended torch_xla[tpuvm] version: v2.1.0 or later. |
| TPU v5p: tpu-v5p-slice | Recommended jax[tpu] version: v0.4.19 or later. Recommended torch_xla[tpuvm] version: a nightly build from October 23, 2023 is suggested. |
| TPU v4: tpu-v4-podslice | Recommended jax[tpu] version: v0.4.4 or later. Recommended torch_xla[tpuvm] version: v2.0.0 or later. |
| TPU v3: tpu-v3-slice, tpu-v3-device | Recommended jax[tpu] version: v0.4.4 or later. Recommended torch_xla[tpuvm] version: v2.0.0 or later. |

- Set the following environment variables for the container requesting the TPU resources:
  - TPU_WORKER_ID: A unique integer for each Pod. This ID denotes a unique worker-id in the TPU slice. The supported values for this field range from zero to the number of Pods minus one.
  - TPU_WORKER_HOSTNAMES: A comma-separated list of TPU VM hostnames or IP addresses that need to communicate with each other within the slice. There should be a hostname or IP address for each TPU VM in the slice. The list of IP addresses or hostnames is ordered and zero indexed by the TPU_WORKER_ID.
GKE automatically injects these environment variables by using a mutating webhook when a Job is created with the completionMode: Indexed, subdomain, parallelism > 1, and requesting google.com/tpu properties. GKE adds a headless Service so that the DNS records are added for the Pods backing the Service.
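If your application needs to inspect the two environment variables described above, a minimal Python sketch could look like the following. The function name `read_tpu_worker_env` and the sample hostnames are hypothetical. Frameworks such as JAX normally read these variables themselves, so this is only useful for logging or debugging.

```python
import os


def read_tpu_worker_env(env=None):
    """Return (worker_id, hostnames, own_hostname) from the TPU variables."""
    env = os.environ if env is None else env
    worker_id = int(env["TPU_WORKER_ID"])
    # The list is comma-separated, ordered, and zero indexed by worker ID.
    hostnames = [h.strip() for h in env["TPU_WORKER_HOSTNAMES"].split(",") if h.strip()]
    if not 0 <= worker_id < len(hostnames):
        raise ValueError(
            f"TPU_WORKER_ID {worker_id} is outside 0..{len(hostnames) - 1}"
        )
    return worker_id, hostnames, hostnames[worker_id]


# Hypothetical values for a two-worker slice; real values are injected by GKE.
sample_env = {
    "TPU_WORKER_ID": "1",
    "TPU_WORKER_HOSTNAMES": "worker-0.example,worker-1.example",
}
worker_id, hostnames, own_hostname = read_tpu_worker_env(sample_env)
```

Passing a dictionary instead of reading `os.environ` directly keeps the helper easy to exercise outside a cluster.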
After you complete the workload preparation, you can run a Job that uses TPUs.
Request TPUs in a workload
This section shows you how to create a Job that requests TPUs in Autopilot. In any workload that needs TPUs, you must specify the following:
- Node selectors for the TPU version and topology
- The number of TPU chips for a container in your workload
For a list of supported TPU versions, topologies, and the corresponding number of TPU chips and nodes in a slice, see Choose an Autopilot TPU configuration.
Considerations for TPU requests in workloads
Only one container in a Pod can use TPUs. The number of TPU chips that a container requests must be equal to the number of TPU chips attached to a node in the slice. For example, if you request TPU v5e (tpu-v5-lite-podslice) with a 2x4 topology, you can request any of the following:
- 4 chips, which creates two multi-host nodes with 4 TPU chips each
- 8 chips, which creates one single-host node with 8 TPU chips
As a best practice to maximize your cost efficiency, always consume all of the TPU chips in the slice that you request. If you request a multi-host slice of two nodes with 4 TPU chips each, deploy a workload that runs on both nodes and consumes all 8 TPU chips in the slice.
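The relationship between topology, chips per node, and node count in the example above is plain arithmetic: the product of the topology dimensions is the total number of chips, and dividing by the chips requested per container gives the number of nodes. The Python sketch below uses a hypothetical helper name and checks only the arithmetic. It does not know which combinations GKE actually supports.

```python
from math import prod


def slice_layout(topology: str, chips_per_node: int):
    """Return (total_chips, node_count) for a topology such as '2x4'."""
    total_chips = prod(int(dim) for dim in topology.split("x"))
    if chips_per_node <= 0 or total_chips % chips_per_node != 0:
        raise ValueError("chips_per_node must evenly divide the slice")
    return total_chips, total_chips // chips_per_node


# The two options for TPU v5e with a 2x4 topology described above.
multi_host = slice_layout("2x4", 4)   # two nodes with 4 chips each
single_host = slice_layout("2x4", 8)  # one node with 8 chips
```

To consume the whole slice, the workload needs one Pod per node, so the node count is also the natural value for the Job's `completions` and `parallelism` fields.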
Create a workload that requests TPUs
The following steps create a Job that requests TPUs. If you have workloads that run on multi-host TPU slices, you must also create a headless Service that selects your workload by name. This headless Service lets Pods on different nodes in the multi-host slice communicate with each other by updating the Kubernetes DNS configuration to point at the Pods in the workload.
Save the following manifest as tpu-autopilot.yaml:
apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-job
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-job
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: TPU_TYPE
        cloud.google.com/gke-tpu-topology: TOPOLOGY
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        command:
        - bash
        - -c
        - |
          pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            cpu: 10
            memory: 500Gi
            google.com/tpu: NUMBER_OF_CHIPS
          limits:
            cpu: 10
            memory: 500Gi
            google.com/tpu: NUMBER_OF_CHIPS
Replace the following:
- TPU_TYPE: the TPU type to use, like tpu-v4-podslice. Must be a value supported by GKE.
- TOPOLOGY: the arrangement of TPU chips in the slice, like 2x2x4. Must be a supported topology for the selected TPU type.
- NUMBER_OF_CHIPS: the number of TPU chips for the container to use. Must be the same value for limits and requests.
Deploy the Job:
kubectl create -f tpu-autopilot.yaml
When you create this Job, GKE automatically does the following:
- Provisions nodes to run the Pods. Depending on the TPU type, topology, and resource requests that you specified, these nodes are either single-host slices or multi-host slices.
- Adds taints to the nodes and tolerations to the Pods to prevent any of your other workloads from running on the same nodes as TPU workloads.
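The manifest in this section depends on a few values lining up: the Service must be headless, its selector must match the Job name, the Pod template's subdomain must match the Service name, and the TPU count must be the same in requests and limits. The following Python sketch checks those relationships on manifests that have already been loaded into dictionaries. The function name `wiring_problems` is hypothetical, and the sample dictionaries are trimmed copies of the manifest with an illustrative chip count of 4.

```python
def wiring_problems(service: dict, job: dict) -> list:
    """Return a list of human-readable mismatches (empty if none found)."""
    problems = []
    svc_name = service["metadata"]["name"]
    job_name = job["metadata"]["name"]
    pod_spec = job["spec"]["template"]["spec"]

    # In YAML, `clusterIP: None` is the string "None", not a null value.
    if service["spec"].get("clusterIP") != "None":
        problems.append("Service is not headless (clusterIP must be None)")
    if service["spec"].get("selector", {}).get("job-name") != job_name:
        problems.append("Service selector job-name does not match the Job name")
    if pod_spec.get("subdomain") != svc_name:
        problems.append("Pod subdomain does not match the Service name")
    for container in pod_spec.get("containers", []):
        resources = container.get("resources", {})
        requested = resources.get("requests", {}).get("google.com/tpu")
        limit = resources.get("limits", {}).get("google.com/tpu")
        if requested != limit:
            problems.append(f"{container['name']}: TPU requests and limits differ")
    return problems


service = {
    "metadata": {"name": "headless-svc"},
    "spec": {"clusterIP": "None", "selector": {"job-name": "tpu-job"}},
}
job = {
    "metadata": {"name": "tpu-job"},
    "spec": {"template": {"spec": {
        "subdomain": "headless-svc",
        "containers": [{
            "name": "tpu-job",
            "resources": {
                "requests": {"google.com/tpu": 4},
                "limits": {"google.com/tpu": 4},
            },
        }],
    }}},
}
problems = wiring_problems(service, job)
```

A check like this can run in CI before `kubectl create`, but it is not a substitute for the validation that GKE performs when you deploy the workload.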
Create a workload that requests TPUs and collection scheduling
In TPU Trillium, you can use collection scheduling to group TPU slice nodes. Grouping these TPU slice nodes makes it easier to adjust the number of replicas to meet the workload demand. Google Cloud controls software updates to ensure that sufficient slices within the collection are always available to serve traffic.
TPU Trillium supports collection scheduling for single-host and multi-host node pools that run inference workloads. The following describes how collection scheduling behavior depends on the type of TPU slice that you use:
- Multi-host TPU slice: GKE groups multi-host TPU slices to form a collection. Each GKE node pool is a replica within this collection. To define a collection, create a multi-host TPU slice and assign a unique name to the collection. To add more TPU slices to the collection, create another multi-host TPU slice node pool with the same collection name and workload type.
- Single-host TPU slice: GKE considers the entire single-host TPU slice node pool as a collection. To add more TPU slices to the collection, you can resize the single-host TPU slice node pool.
To learn about the limitations of collection scheduling, see How collection scheduling works.
Use a multi-host TPU slice
Collection scheduling in multi-host TPU slice nodes is available for Autopilot clusters in version 1.31.2-gke.1537000 and later. Multi-host TPU slice nodes with a 2x4 topology are only supported in 1.31.2-gke.1115000 or later. To create multi-host TPU slice nodes and group them as a collection, add the following Kubernetes labels to your workload specification:
- cloud.google.com/gke-nodepool-group-name: each collection should have a unique name at the cluster level. The value in the cloud.google.com/gke-nodepool-group-name label must adhere to requirements for cluster labels.
- cloud.google.com/gke-workload-type: HIGH_AVAILABILITY
For example, the following code block defines a collection with a multi-host TPU slice:
nodeSelector:
  cloud.google.com/gke-nodepool-group-name: ${COLLECTION_NAME}
  cloud.google.com/gke-workload-type: HIGH_AVAILABILITY
  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
  cloud.google.com/gke-tpu-topology: 4x4
...
Use a single-host TPU slice
Collection scheduling in single-host TPU slice nodes is available for Autopilot clusters in version 1.31.2-gke.1088000 and later. To create single-host TPU slice nodes and group them as a collection, add the cloud.google.com/gke-workload-type: HIGH_AVAILABILITY label in your workload specification.
For example, the following code block defines a collection with a single-host TPU slice:
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
cloud.google.com/gke-workload-type: HIGH_AVAILABILITY
...
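The two node selector variants above differ only in whether a collection name is present. As a small sketch, the Python helper below (a hypothetical name, not a GKE API) produces either variant from the labels shown in this section: pass a collection name for a multi-host slice and omit it for a single-host slice.

```python
def collection_node_selector(topology: str, collection_name=None) -> dict:
    """Build the nodeSelector for a TPU Trillium collection."""
    selector = {
        "cloud.google.com/gke-tpu-accelerator": "tpu-v6e-slice",
        "cloud.google.com/gke-tpu-topology": topology,
        "cloud.google.com/gke-workload-type": "HIGH_AVAILABILITY",
    }
    if collection_name is not None:
        # Multi-host collections are identified by a cluster-unique name.
        selector["cloud.google.com/gke-nodepool-group-name"] = collection_name
    return selector


# "my-collection" is an illustrative name only.
multi_host_selector = collection_node_selector("4x4", "my-collection")
single_host_selector = collection_node_selector("2x2")
```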
Centrally provision TPUs with custom compute classes
To provision TPUs with a custom compute class, do the following:
Ensure that your cluster has an available custom compute class that selects TPUs. To learn how to specify TPUs in custom compute classes, see TPU rules.
Save the following manifest as tpu-job.yaml:
apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-job
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-job
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/compute-class: TPU_CLASS_NAME
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        command:
        - bash
        - -c
        - |
          pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            cpu: 10
            memory: 500Gi
            google.com/tpu: NUMBER_OF_CHIPS
          limits:
            cpu: 10
            memory: 500Gi
            google.com/tpu: NUMBER_OF_CHIPS
Replace the following:
- TPU_CLASS_NAME: the name of the existing custom compute class that specifies TPUs.
- NUMBER_OF_CHIPS: the number of TPU chips for the container to use. Must be the same value for limits and requests, equal to the value in the tpu.count field in the selected custom compute class.
Deploy the Job:
kubectl create -f tpu-job.yaml
When you create this Job, GKE automatically does the following:
- Provisions nodes to run the Pods. Depending on the TPU type, topology, and resource requests that you specified, these nodes are either single-host slices or multi-host slices. Depending on the availability of TPU resources in the top priority, GKE might fall back to lower priorities to maximize obtainability.
- Adds taints to the nodes and tolerations to the Pods to prevent any of your other workloads from running on the same nodes as TPU workloads.
To learn more, see About custom compute classes.
Example: Display the total TPU chips in a multi-host slice
The following workload returns the number of TPU chips across all of the nodes in a multi-host TPU slice. To create a multi-host slice, the workload has the following parameters:
- TPU version: TPU v4
- Topology: 2x2x4
This version and topology selection result in a multi-host slice.
- Save the following manifest as available-chips-multihost.yaml:
apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
        cloud.google.com/gke-tpu-topology: 2x2x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        command:
        - bash
        - -c
        - |
          pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          python -c 'import jax; print("TPU cores:", jax.device_count())'
        resources:
          requests:
            cpu: 10
            memory: 500Gi
            google.com/tpu: 4
          limits:
            cpu: 10
            memory: 500Gi
            google.com/tpu: 4
- Deploy the manifest:
kubectl create -f available-chips-multihost.yaml
GKE runs a TPU v4 slice with four VMs (multi-host TPU slice). The slice has 16 interconnected TPU chips.
- Verify that the Job created four Pods:
kubectl get pods
The output is similar to the following:
NAME                       READY   STATUS      RESTARTS   AGE
tpu-job-podslice-0-5cd8r   0/1     Completed   0          97s
tpu-job-podslice-1-lqqxt   0/1     Completed   0          97s
tpu-job-podslice-2-f6kwh   0/1     Completed   0          97s
tpu-job-podslice-3-m8b5c   0/1     Completed   0          97s
- Get the logs of one of the Pods:
kubectl logs POD_NAME
Replace POD_NAME with the name of one of the created Pods. For example, tpu-job-podslice-0-5cd8r.
The output is similar to the following:
TPU cores: 16
Example: Display the TPU chips in a single node
The following workload is a standalone Pod that displays the number of TPU chips that are attached to a specific node. To create a single-host node, the workload has the following parameters:
- TPU version: TPU v5e
- Topology: 2x4
This version and topology selection result in a single-host slice.
- Save the following manifest as available-chips-singlehost.yaml:
apiVersion: v1
kind: Pod
metadata:
  name: tpu-job-jax-v5
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
    cloud.google.com/gke-tpu-topology: 2x4
  containers:
  - name: tpu-job
    image: python:3.10
    ports:
    - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
    command:
    - bash
    - -c
    - |
      pip install 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      python -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 8
      limits:
        google.com/tpu: 8
- Deploy the manifest:
kubectl create -f available-chips-singlehost.yaml
GKE provisions a single-host TPU slice node that uses TPU v5e. The node has eight TPU chips.
- Get the logs of the Pod:
kubectl logs tpu-job-jax-v5
The output is similar to the following:
Total TPU chips: 8
Observability and metrics
Dashboard
In the Kubernetes Clusters page in the Google Cloud console, the Observability tab displays the TPU observability metrics. For more information, see GKE observability metrics.
The TPU dashboard is populated only if you have system metrics enabled in your GKE cluster.
Runtime metrics
In GKE version 1.27.4-gke.900 or later, TPU workloads that use JAX version 0.4.14 or later and specify containerPort: 8431 export TPU utilization metrics as GKE system metrics.
The following metrics are available in Cloud Monitoring to monitor your TPU workload's runtime performance:
- Duty cycle: Percentage of time over the past sampling period (60 seconds) during which the TensorCores were actively processing on a TPU chip. Larger percentage means better TPU utilization.
- Memory used: Amount of accelerator memory allocated in bytes. Sampled every 60 seconds.
- Memory total: Total accelerator memory in bytes. Sampled every 60 seconds.
These metrics are located in the Kubernetes node (k8s_node) and Kubernetes container (k8s_container) schema.
Kubernetes container:
kubernetes.io/container/accelerator/duty_cycle
kubernetes.io/container/accelerator/memory_used
kubernetes.io/container/accelerator/memory_total
Kubernetes node:
kubernetes.io/node/accelerator/duty_cycle
kubernetes.io/node/accelerator/memory_used
kubernetes.io/node/accelerator/memory_total
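When you query these metrics programmatically, you typically pass a filter string that names the metric type and the monitored resource type. The Python sketch below only builds such a string from the metric names listed above. It assumes the common Cloud Monitoring filter form (`metric.type = "..." AND resource.type = "..."`) and the `cluster_name` resource label; verify both against the Monitoring documentation for your setup. The function name is hypothetical.

```python
RUNTIME_METRICS = ("duty_cycle", "memory_used", "memory_total")
SCOPES = ("container", "node")


def runtime_metric_filter(metric: str, scope: str = "container", cluster_name=None) -> str:
    """Build a filter string for one TPU runtime metric (assumed syntax)."""
    if metric not in RUNTIME_METRICS:
        raise ValueError(f"unknown runtime metric: {metric}")
    if scope not in SCOPES:
        raise ValueError(f"scope must be one of {SCOPES}")
    parts = [
        f'metric.type = "kubernetes.io/{scope}/accelerator/{metric}"',
        f'resource.type = "k8s_{scope}"',
    ]
    if cluster_name:
        # Assumption: the monitored resource exposes a cluster_name label.
        parts.append(f'resource.labels.cluster_name = "{cluster_name}"')
    return " AND ".join(parts)


duty_cycle_filter = runtime_metric_filter("duty_cycle", "node")
```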
Host metrics
In GKE version 1.28.1-gke.1066000 or later, VMs in a TPU slice export TPU utilization metrics as GKE system metrics. The following metrics are available in Cloud Monitoring to monitor your TPU host's performance:
- TensorCore utilization: Current percentage of the TensorCore that is utilized. The TensorCore value equals the sum of the matrix-multiply units (MXUs) plus the vector unit. The TensorCore utilization value is the number of TensorCore operations that were performed over the past sample period (60 seconds) divided by the supported number of TensorCore operations over the same period. A larger value means better utilization.
- Memory Bandwidth utilization: Current percentage of the accelerator memory bandwidth that is being used. Computed by dividing the memory bandwidth used over a sample period (60s) by the maximum supported bandwidth over the same sample period.
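Both host metrics above are ratios of what was used to what the hardware supports over the same 60-second sample period, expressed as a percentage. The following Python sketch restates that calculation. The function name and the sample numbers are illustrative only, and GKE computes the real values for you.

```python
def utilization_percent(used: float, supported: float) -> float:
    """Percentage of the supported capacity that was used in one period."""
    if supported <= 0:
        raise ValueError("supported capacity must be positive")
    return 100.0 * used / supported


# Illustrative numbers: 30 units of work performed out of 40 supported
# during one sample period.
tensorcore_utilization = utilization_percent(30, 40)
```

The same function applies to memory bandwidth utilization when `used` is the bandwidth consumed and `supported` is the maximum bandwidth for the period.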
These metrics are located in the Kubernetes node (k8s_node) and Kubernetes container (k8s_container) schema.
Kubernetes container:
kubernetes.io/container/accelerator/tensorcore_utilization
kubernetes.io/container/accelerator/memory_bandwidth_utilization
Kubernetes node:
kubernetes.io/node/accelerator/tensorcore_utilization
kubernetes.io/node/accelerator/memory_bandwidth_utilization
For more information, see Kubernetes metrics and GKE system metrics.
Logging
Logs emitted by containers running on GKE nodes, including TPU VMs, are collected by the GKE logging agent, sent to Logging, and are visible in Logging.
Recommendations for TPU workloads in Autopilot
The following recommendations might improve the efficiency of your TPU workloads:
- Use extended run time Pods for a grace period of up to seven days before GKE terminates your Pods for scale-downs or node upgrades. You can use maintenance windows and exclusions with extended run time Pods to further delay automatic node upgrades.
- Use capacity reservations to ensure that your workloads receive requested TPUs without being placed in a queue for availability.