Scale ML workloads using Ray
This document describes how to run machine learning (ML) workloads on TPUs with Ray, using either PyTorch/XLA or JAX. There are two different modes for using TPUs with Ray: device-centric mode (PyTorch/XLA) and host-centric mode (JAX).
This document assumes that you already have a TPU environment set up. For more information, see the following resources:
- Cloud TPU: Set up the Cloud TPU environment and Manage TPU resources
- Google Kubernetes Engine (GKE): Deploy TPU workloads in GKE Autopilot or Deploy TPU workloads in GKE Standard
Device-centric mode (PyTorch/XLA)
Device-centric mode retains much of the programmatic style of classic PyTorch. In this mode, you add a new XLA device type, which works like any other PyTorch device. Each individual process interacts with one XLA device.
This mode is ideal if you are already familiar with PyTorch with GPUs and want to use similar coding abstractions.
The following sections describe how to run a PyTorch/XLA workload on one or more devices without using Ray, then how to run the same workload on multiple hosts using Ray.
Create a TPU
Create environment variables for TPU creation parameters:
export TPU_NAME=TPU_NAME
export ZONE=europe-west4-b
export ACCELERATOR_TYPE=v5p-8
export VERSION=v2-alpha-tpuv5
Environment variable descriptions
TPU_NAME
- The name for your new Cloud TPU.
ZONE
- The zone in which to create your Cloud TPU.
ACCELERATOR_TYPE
- The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information, see TPU versions.
VERSION
- The TPU software version you want to use. For more information, see TPU VM images.
Use the following command to create a v5p TPU VM with 8 cores:
gcloud compute tpus tpu-vm create $TPU_NAME \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$VERSION
Connect to the TPU VM using the following command:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
If you're using GKE, see the KubeRay on GKE guide for setup information.
Install requirements
Run the following commands on your TPU VM to install required dependencies:
Save the following to a file, for example, requirements.txt:

--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
torch~=2.6.0
torch_xla[tpu]~=2.6.0
ray[default]==2.40.0
Run the following command to install required dependencies:
pip install -r requirements.txt
If you're running your workload on GKE, we recommend creating a Dockerfile that installs the required dependencies. For an example, see Run your workload on TPU slice nodes in the GKE documentation.
Run a PyTorch/XLA workload on a single device
The following example demonstrates how to create an XLA tensor on a single device (a TPU chip). This is similar to how PyTorch handles other device types.
Save the following code snippet to a file, for example, workload.py:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
The import torch_xla statement initializes PyTorch/XLA, and the xm.xla_device() function returns the current XLA device, a TPU chip.
Set the PJRT_DEVICE environment variable to TPU:

export PJRT_DEVICE=TPU
Run the script:
python workload.py
The output looks similar to the following. Make sure that the output indicates that the XLA device is found.
xla:0
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
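Because the XLA device behaves like any other PyTorch device, you can move tensors between the host and the TPU chip in the usual way. The following is a small optional sketch, not part of the original example:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(2, 2, device=device)
b = torch.ones(2, 2).to(device)  # Move a CPU tensor onto the TPU chip
c = (a @ b).cpu()                # Compute on the TPU, then copy the result back to the host
print(c)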
Run PyTorch/XLA on multiple devices
Update the code snippet from the previous section to run on multiple devices:
import torch
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(index):
    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)

if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
Run the script:
python workload.py
If you run the code snippet on a TPU v5p-8, the output looks similar to the following:
xla:0
xla:0
xla:0
tensor([[ 1.2309,  0.9896],
        [ 0.5820, -1.2950]], device='xla:0')
xla:0
tensor([[ 1.2309,  0.9896],
        [ 0.5820, -1.2950]], device='xla:0')
tensor([[ 1.2309,  0.9896],
        [ 0.5820, -1.2950]], device='xla:0')
tensor([[ 1.2309,  0.9896],
        [ 0.5820, -1.2950]], device='xla:0')
torch_xla.launch() takes two arguments: a function and a list of parameters. It creates a process for each available XLA device and calls the function specified in the arguments. In this example, there are 4 TPU devices available, so torch_xla.launch() creates 4 processes and calls _mp_fn() on each device. Each process only has access to one device, so each device has the index 0, and xla:0 is printed for all processes.
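To tell the processes apart, you can print each process's global ordinal alongside the device. The following is a minimal variation of the previous snippet; it assumes the torch_xla.runtime module that ships with PyTorch/XLA 2.x:

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def _mp_fn(index):
    # Each process sees its device as xla:0, but the runtime assigns it a
    # unique global ordinal within the slice.
    t = torch.randn(2, 2, device=xm.xla_device())
    print(f"ordinal {xr.global_ordinal()} of {xr.world_size()}: {t.device}")

if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())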
Run PyTorch/XLA on multiple hosts with Ray
The following sections show how to run the same code snippet on a larger multi-host TPU slice. For more information about the multi-host TPU architecture, see System architecture.
In this example, you manually set up Ray. If you are already familiar with setting up Ray, you can skip to the last section, Run a Ray workload. For more information about setting up Ray for a production environment, see the Ray cluster deployment documentation.
Create a multi-host TPU VM
Create environment variables for TPU creation parameters:
export TPU_NAME_MULTIHOST=TPU_NAME_MULTIHOST
export ZONE=europe-west4-b
export ACCELERATOR_TYPE_MULTIHOST=v5p-16
export VERSION=v2-alpha-tpuv5
Create a multi-host TPU v5p with 2 hosts (a v5p-16, with 4 TPU chips on each host) using the following command:
gcloud compute tpus tpu-vm create $TPU_NAME_MULTIHOST \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE_MULTIHOST \
    --version=$VERSION
Set up Ray
A TPU v5p-16 has 2 TPU hosts, each with 4 TPU chips. In this example, you will start the Ray head node on one host and add the second host as a worker node to the Ray cluster.
Connect to the first host using SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=0
Install dependencies with the same requirements file as in the Install requirements section:
pip install -r requirements.txt
Start the Ray process:
ray start --head --port=6379
The output looks similar to the following:
Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.
Local node IP: 10.130.0.76

--------------------
Ray runtime started.
--------------------

Next steps
  To add another node to this Ray cluster, run
    ray start --address='10.130.0.76:6379'

  To connect to this Ray cluster:
    import ray
    ray.init()

  To terminate the Ray runtime, run
    ray stop

  To view the status of the cluster, use
    ray status
This TPU host is now the Ray head node. Make a note of the lines that show how to add another node to the Ray cluster, similar to the following:
To add another node to this Ray cluster, run ray start --address='10.130.0.76:6379'
You will use this command in a later step.
Check the Ray cluster status:
ray status
The output looks similar to the following:
======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
Node status
---------------------------------------------------------------
Active:
 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/208.0 CPU
 0.0/4.0 TPU
 0.0/1.0 TPU-v5p-16-head
 0B/268.44GiB memory
 0B/119.04GiB object_store_memory
 0.0/1.0 your-tpu-name

Demands:
 (no resource demands)
The cluster only contains 4 TPUs (0.0/4.0 TPU) because you've only added the head node so far.
Now that the head node is running, you can add the second host to the cluster.
Connect to the second host using SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME_MULTIHOST --zone=$ZONE --worker=1
Install dependencies with the same requirements file as in the Install requirements section:
pip install -r requirements.txt
Start the Ray process. Use the command from the output of the ray start command to add this node to the existing Ray cluster. Make sure to replace the IP address and port in the following command:

ray start --address='10.130.0.76:6379'
The output looks similar to the following:
Local node IP: 10.130.0.80
[2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1

--------------------
Ray runtime started.
--------------------

To terminate the Ray runtime, run
  ray stop
Check the Ray status again:
ray status
The output looks similar to the following:
======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
Node status
---------------------------------------------------------------
Active:
 1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
 1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/416.0 CPU
 0.0/8.0 TPU
 0.0/1.0 TPU-v5p-16-head
 0B/546.83GiB memory
 0B/238.35GiB object_store_memory
 0.0/2.0 your-tpu-name

Demands:
 (no resource demands)
The second TPU host is now a node in the cluster. The list of available resources now shows 8 TPUs (0.0/8.0 TPU).
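You can also check the available resources programmatically instead of using the CLI. The following is a minimal sketch that uses Ray's ray.available_resources() API from the head node; the TPU key matches the custom resource shown in the ray status output:

import ray

# Connect to the cluster that is already running on this host.
ray.init(address="auto")

# "TPU" is the custom resource that Ray registers for TPU chips.
print(ray.available_resources().get("TPU", 0))  # Expected: 8.0 after both hosts join

ray.shutdown()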
Run a Ray workload
Update the code snippet to run on the Ray cluster:
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import ray
import torch.distributed as dist
import torch_xla.runtime as xr
from torch_xla._internal import pjrt

# Defines the local PJRT world size, the number of processes per host.
LOCAL_WORLD_SIZE = 4
# Defines the number of hosts in the Ray cluster.
NUM_OF_HOSTS = 2
GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS

def init_env():
    local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])

    pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
    xr._init_world_size_ordinal()

# This decorator signals to Ray that the print_tensor() function should be
# run on a single TPU chip.
@ray.remote(resources={"TPU": 1})
def print_tensor():
    # Initializes the runtime environment on each Ray worker. Equivalent to
    # the `torch_xla.launch` call in the Run PyTorch/XLA on multiple devices section.
    init_env()

    t = torch.randn(2, 2, device=xm.xla_device())
    print(t.device)
    print(t)

ray.init()

# Uses Ray to dispatch the function call across available nodes in the cluster.
tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
ray.get(tasks)

ray.shutdown()
Save the updated code snippet to a file, for example, ray-workload.py, and run the script on the Ray head node. Replace ray-workload.py with the path to your script.
python ray-workload.py
The output is similar to the following:
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
xla:0
xla:0
xla:0
xla:0
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
xla:0
xla:0
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
xla:0
tensor([[ 0.6220, -1.4707],
        [-1.2112,  0.7024]], device='xla:0')
The output indicates that the function was successfully called on each XLA device (8 devices in this example) in the multi-host TPU slice.
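If you prefer a check that doesn't depend on reading interleaved log output, you can return the device string from each task and count the results. The following is a hedged variation that reuses init_env(), xm, and GLOBAL_WORLD_SIZE from the script above; it is not part of the original example:

# Variation: return the device string instead of printing it.
@ray.remote(resources={"TPU": 1})
def get_device() -> str:
    init_env()
    return str(xm.xla_device())

devices = ray.get([get_device.remote() for _ in range(GLOBAL_WORLD_SIZE)])
assert len(devices) == GLOBAL_WORLD_SIZE  # One result per TPU chip (8 in this example)
print(devices)  # Every entry is 'xla:0' because each process sees a single device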
Host-centric mode (JAX)
The following sections describe the host-centric mode with JAX. JAX utilizes a functional programming paradigm and supports higher-level single program, multiple data (SPMD) semantics. Instead of having each process interact with a single XLA device, JAX code is designed to operate across multiple devices on a single host concurrently.
JAX is designed for high-performance computing and can efficiently use TPUs for large-scale training and inference. This mode is ideal if you're familiar with functional programming concepts, because that familiarity lets you take full advantage of JAX.
These instructions assume that you already have a Ray and TPU environment set up, including a software environment that includes JAX and other related packages. To create a Ray TPU cluster, follow the instructions at Start Google Cloud GKE Cluster with TPUs for KubeRay. For more information about using TPUs with KubeRay, see Use TPUs with KubeRay.
Run a JAX workload on a single-host TPU
The following example script demonstrates how to run a JAX function on a Ray cluster with a single-host TPU, such as a v6e-4. If you have a multi-host TPU, this script stops responding due to JAX's multi-controller execution model. For more information about running Ray on a multi-host TPU, see Run a JAX workload on a multi-host TPU.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
h = my_function.remote()
print(ray.get(h)) # => 4
If you're used to running Ray with GPUs, there are some key differences when using TPUs:

- Rather than setting num_gpus, you specify TPU as a custom resource and set the number of TPU chips.
- You specify the number of TPU chips per Ray worker node. For example, if you're using a v6e-4, running a remote function with TPU set to 4 consumes the entire TPU host.
  - This is different from how GPUs typically run, with one process per host. Setting TPU to a number that isn't 4 is not recommended.
  - Exception: If you have a single-host v6e-8 or v5litepod-8, set this value to 8, as shown in the sketch after this list.
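The following is a minimal sketch of that exception, assuming a single-host v6e-8 attached to the Ray cluster; the function body mirrors the earlier example:

import ray
import jax

# On a single-host v6e-8 (or v5litepod-8), one Ray worker node owns all 8 chips,
# so the remote function reserves all of them.
@ray.remote(resources={"TPU": 8})
def my_function() -> int:
    return jax.device_count()

h = my_function.remote()
print(ray.get(h))  # => 8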
Run a JAX workload on a multi-host TPU
The following example script demonstrates how to run a JAX function on a Ray cluster with a multi-host TPU. The example script uses a v6e-16.
import ray
import jax
@ray.remote(resources={"TPU": 4})
def my_function() -> int:
return jax.device_count()
num_tpus = ray.available_resources()["TPU"]
num_hosts = int(num_tpus) // 4
h = [my_function.remote() for _ in range(num_hosts)]
print(ray.get(h)) # [16, 16, 16, 16]
If you're used to running Ray with GPUs, there are some key differences when using TPUs:
- Similar to PyTorch workloads on GPUs:
- JAX workloads on TPUs run in a multi-controller, single program, multiple data (SPMD) fashion.
- Collectives between devices are handled by the machine learning framework.
- Unlike PyTorch workloads on GPUs, JAX has a global view of the available devices in the cluster.
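To see this global view directly, you can return both the local and the global device count from each host. The following is a hedged sketch that assumes the same v6e-16 cluster as the previous example:

import ray
import jax

@ray.remote(resources={"TPU": 4})
def device_counts() -> tuple:
    # local_device_count() returns the chips attached to this host (4 on a v6e-16),
    # while device_count() returns every chip in the slice (16).
    return jax.local_device_count(), jax.device_count()

num_hosts = int(ray.available_resources()["TPU"]) // 4
h = [device_counts.remote() for _ in range(num_hosts)]
print(ray.get(h))  # => [(4, 16), (4, 16), (4, 16), (4, 16)]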
Run a Multislice JAX workload
Multislice lets you run workloads that span multiple TPU slices within a single TPU Pod or in multiple Pods over the data center network.
You can use the ray-tpu package to simplify Ray's interactions with TPU slices. Install ray-tpu using pip:
pip install ray-tpu
The following example script shows how to use the ray-tpu package to run Multislice workloads using Ray actors or tasks:
from ray_tpu import RayTpuManager
import jax
import ray
ray.init()
# note - don't set resources as they will be overridden
@ray.remote
class MyActor:
def get_devices(self):
return jax.device_count()
# note - don't set resources as they will be overridden
@ray.remote
def get_devices() -> int:
return jax.device_count()
tpus = RayTpuManager.get_available_resources()
print("TPU resources: ", tpus)
"""
TPU resources:
{'v6e-16': [
RayTpu(name='tpu-group-1', num_hosts=4, head_ip='10.36.3.5', topology='v6e-16'),
RayTpu(name='tpu-group-0', num_hosts=4, head_ip='10.36.10.7', topology='v6e-16')
]}
"""
# if using actors
actors = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=MyActor,
multislice=True,
)
h = [actor.get_devices.remote() for actor in actors]
ray.get(h) # => [32, 32, 32, 32, 32, 32, 32, 32]
# if using tasks
h = RayTpuManager.remote(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=True,
)
ray.get(h) # [32, 32, 32, 32, 32, 32, 32, 32]
# note - you can also run this without Multislice
h = RayTpuManager.run_task(
tpus=tpus["v6e-16"],
actor_or_fn=get_devices,
multislice=False,
)
ray.get(h) # => [16, 16, 16, 16, 16, 16, 16, 16]
Orchestrate workloads using Ray and MaxText
This section describes how to use Ray to orchestrate workloads using MaxText, a scalable and high performance open source library for training LLMs using JAX and XLA.
MaxText contains a training script, train.py, which needs to run on each TPU host. This is similar to other SPMD machine learning workloads. You can achieve this using the ray-tpu package and creating a wrapper around the train.py main function. The following steps show how to use the ray-tpu package to run MaxText on a TPU v6e-16.
Set environment variables for TPU creation parameters:
export TPU_NAME=TPU_NAME
export ZONE=ZONE
export ACCELERATOR_TYPE=v6e-16
export VERSION=v2-alpha-tpuv6e
Create a TPU v6e-16:
gcloud compute tpus tpu-vm create $TPU_NAME \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$VERSION
Clone the MaxText repository on all TPU workers:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --zone=$ZONE \
    --worker=all \
    --command="git clone https://github.com/AI-Hypercomputer/maxtext"
Install MaxText requirements on all TPU workers:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --zone=$ZONE \
    --worker=all \
    --command="pip install -r maxtext/requirements.txt"
Install the ray-tpu package on all TPU workers:

gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --zone=$ZONE \
    --worker=all \
    --command="pip install ray-tpu"
Connect to worker 0 using SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --zone=$ZONE \
    --worker=0
Save the following script to a file named ray_trainer.py in the ~/maxtext/MaxText directory. This script uses the ray-tpu package and creates a wrapper around MaxText's train.py main function.

import ray
import ray_tpu
from train import main as maxtext_main

import logging
from typing import Sequence
from absl import app

# Default env vars that run on all TPU VMs.
MACHINE_ENV_VARS = {
    "ENABLE_PJRT_COMPATIBILITY": "true",
    "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
    "TPU_SLICE_BUILDER_DUMP_ICI": "true",
    "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",  # Dumps HLOs for debugging
}

def setup_loggers():
    """Sets up loggers for Ray."""
    logging.basicConfig(level=logging.INFO)

@ray_tpu.remote(
    topology={"v6e-16": 1},
)
def run_maxtext_train(argv: Sequence[str]):
    maxtext_main(argv=argv)

def main(argv: Sequence[str]):
    ray.init(runtime_env=dict(worker_process_setup_hook=setup_loggers))

    logging.info(f"argv: {argv}")

    try:
        ray.get(run_maxtext_train(argv=argv))
    except Exception as e:
        logging.error("Caught error during training: %s", e)
        logging.error("Shutting down...")
        ray.shutdown()
        raise e

    logging.info("Training complete!")
    ray.shutdown()

if __name__ == "__main__":
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    app.run(main)
Execute the script by running the following command:
python maxtext/MaxText/ray_trainer.py maxtext/MaxText/configs/base.yml \
    base_output_directory=/tmp/maxtext \
    dataset_type=synthetic \
    per_device_batch_size=2 \
    max_target_length=8192 \
    model_name=default \
    steps=100 \
    run_name=test
The output looks similar to the following:
(run_maxtext_train pid=78967, ip=10.130.0.11) Started an asynchronous checkpoint save for step 0
(run_maxtext_train pid=78967, ip=10.130.0.11)
(run_maxtext_train pid=78967, ip=10.130.0.11) Memstats: After params initialized:
(run_maxtext_train pid=78967, ip=10.130.0.11)  Using (GB) 1.59 / 30.75 (5.170732%) on TPU_4(process=1,(0,0,1,0))
(run_maxtext_train pid=78967, ip=10.130.0.11)  Using (GB) 1.59 / 30.75 (5.170732%) on TPU_5(process=1,(1,0,1,0))
(run_maxtext_train pid=78967, ip=10.130.0.11)  Using (GB) 1.59 / 30.75 (5.170732%) on TPU_6(process=1,(0,1,1,0))
(run_maxtext_train pid=78967, ip=10.130.0.11)  Using (GB) 1.59 / 30.75 (5.170732%) on TPU_7(process=1,(1,1,1,0))
(run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 0, seconds: 11.775, TFLOP/s/device: 13.153, Tokens/s/device: 1391.395, total_weights: 131072, loss: 12.066
(run_maxtext_train pid=80538, ip=10.130.0.12)
(run_maxtext_train pid=80538, ip=10.130.0.12) To see full metrics 'tensorboard --logdir=/tmp/maxtext/test/tensorboard/'
(run_maxtext_train pid=80538, ip=10.130.0.12) Waiting for step 0 to finish before checkpoint...
(run_maxtext_train pid=80538, ip=10.130.0.12) Waited 0.7087039947509766 seconds for step 0 to finish before starting checkpointing.
(run_maxtext_train pid=80538, ip=10.130.0.12) Started an asynchronous checkpoint save for step 0
(run_maxtext_train pid=80538, ip=10.130.0.12) Memstats: After params initialized:
(run_maxtext_train pid=80538, ip=10.130.0.12)  Using (GB) 1.59 / 30.75 (5.170732%) on TPU_3(process=0,(1,1,0,0)) [repeated 4x across cluster]
(run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 4, seconds: 1.116, TFLOP/s/device: 138.799, Tokens/s/device: 14683.240, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
(run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 9, seconds: 1.068, TFLOP/s/device: 145.065, Tokens/s/device: 15346.083, total_weights: 131072, loss: 0.000 [repeated 9x across cluster]
(run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 14, seconds: 1.116, TFLOP/s/device: 138.754, Tokens/s/device: 14678.439, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
...
(run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 89, seconds: 1.116, TFLOP/s/device: 138.760, Tokens/s/device: 14679.083, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
(run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 94, seconds: 1.091, TFLOP/s/device: 141.924, Tokens/s/device: 15013.837, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
(run_maxtext_train pid=78967, ip=10.130.0.11) completed step: 99, seconds: 1.116, TFLOP/s/device: 138.763, Tokens/s/device: 14679.412, total_weights: 131072, loss: 0.000 [repeated 10x across cluster]
(run_maxtext_train pid=80538, ip=10.130.0.12) Output size: 1657041920, temp size: 4907988480, argument size: 1657366016, host temp size: 0, in bytes.
I0121 01:39:46.830807 130655182204928 ray_trainer.py:47] Training complete!
(run_maxtext_train pid=80538, ip=10.130.0.12) completed step: 99, seconds: 1.191, TFLOP/s/device: 130.014, Tokens/s/device: 13753.874, total_weights: 131072, loss: 0.000
TPU and Ray resources
Ray treats TPUs differently from GPUs to accommodate the difference in usage. In the following example, there are nine Ray nodes in total:
- The Ray head node is running on an n1-standard-16 VM.
- The Ray worker nodes are running on two v6e-16 TPUs. Each TPU slice provides four worker nodes, one per TPU host.
$ ray status
======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
Node status
---------------------------------------------------------------
Active:
1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/727.0 CPU
0.0/32.0 TPU
0.0/2.0 TPU-v6e-16-head
0B/5.13TiB memory
0B/1.47TiB object_store_memory
0.0/4.0 tpu-group-0
0.0/4.0 tpu-group-1
Demands:
(no resource demands)
Resource usage field descriptions:
- CPU: The total number of CPUs available in the cluster.
- TPU: The number of TPU chips in the cluster.
- TPU-v6e-16-head: A special identifier for the resource that corresponds with worker 0 of a TPU slice. This is important for accessing individual TPU slices.
- memory: Worker heap memory used by your application.
- object_store_memory: Memory used when your application creates objects in the object store using ray.put and when it returns values from remote functions.
- tpu-group-0 and tpu-group-1: Unique identifiers for the individual TPU slices. This is important for running jobs on slices, as shown in the sketch after this list. These fields are set to 4 because there are 4 hosts per TPU slice in a v6e-16.
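As an illustration of how these identifiers can be used, the following hedged sketch schedules one task on each host of the slice named tpu-group-0 by combining the slice identifier with the TPU chip resource. The resource keys come from the ray status output above; the function body is only an example:

import ray
import jax

ray.init()

# Pin each task to one host of the slice "tpu-group-0" and reserve that host's 4 chips.
@ray.remote(resources={"tpu-group-0": 1, "TPU": 4})
def run_on_slice_0() -> int:
    return jax.device_count()

# tpu-group-0 has a capacity of 4, one unit per host, so launch 4 tasks.
print(ray.get([run_on_slice_0.remote() for _ in range(4)]))

ray.shutdown()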