Serve open models using Hex-LLM premium container on Cloud TPU

This guide describes how to serve open models using Hex-LLM, a high-performance serving framework for Cloud TPU. It covers Hex-LLM's features and advanced serving capabilities, how to get started in Model Garden, how to configure and tune the server arguments, and how to request Cloud TPU quota.

Hex-LLM (High-efficiency large language model serving with XLA) is a Vertex AI LLM serving framework designed and optimized for Cloud TPU hardware. Hex-LLM combines LLM serving technologies such as continuous batching and PagedAttention with Vertex AI optimizations that are tailored for XLA and Cloud TPU. It provides high-efficiency and low-cost LLM serving on Cloud TPU for open source models. You can use Hex-LLM in Model Garden through the model playground, one-click deployment, or notebooks.
Features

Hex-LLM is based on open source projects and includes Google's optimizations for XLA and Cloud TPU. This design helps Hex-LLM achieve high throughput and low latency when serving frequently used LLMs. Hex-LLM includes optimizations such as continuous batching and PagedAttention tailored for XLA, supports a wide range of dense and sparse LLMs, and provides the advanced serving features described in the next section.
Advanced features

Hex-LLM supports the following advanced features:
Multi-host serving: Serves models across multiple TPU host VMs. Use this for large models that do not fit into a single host's memory (for example, Llama 3.1 70B).
Disaggregated serving: Separates the prefill and decode phases into different workloads to balance latency and throughput. Use this for applications with strict latency requirements where balancing Time to First Token (TTFT) and Time Per Output Token (TPOT) is critical.
Prefix caching: Caches computations for common prompt prefixes to speed up processing. Use this to reduce Time to First Token (TTFT) for prompts with recurring initial content, like system instructions or conversation history.
Chunked prefill: Splits long prompt prefills into smaller chunks to balance latency and improve throughput. Use this to improve performance and support longer context lengths by mixing prefill and decode steps.
4-bit quantization: Reduces memory and computational costs by using low-precision data types for model weights. Use this to serve large models more efficiently with lower resource requirements.
Multi-host serving

Hex-LLM supports serving models with a multi-host TPU slice. This feature lets you serve large models that cannot be loaded into a single host TPU VM, which contains a maximum of eight v5e cores. To enable this feature, set --num_hosts in the Hex-LLM container arguments and set --tpu_topology in the Vertex AI SDK model upload request. The following example shows how to deploy the Hex-LLM container with a TPU 4x4 v5e topology that serves the Llama 3.1 70B bfloat16 model:

hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Meta-Llama-3.1-70B",
"--data_parallel_size=1",
"--tensor_parallel_size=16",
"--num_hosts=4",
"--hbm_utilization_factor=0.9",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
tpu_topology="4x4",
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
For an end-to-end tutorial for deploying the Hex-LLM container with a multi-host TPU topology, see the Vertex AI Model Garden - Llama 3.1 (Deployment) notebook.

To enable multi-host serving, you generally only need to make the following changes (the sketch after this list shows how these values relate to the TPU topology):

Set the --tensor_parallel_size argument to the total number of cores within the TPU topology.
Set the --num_hosts argument to the number of hosts within the TPU topology.
Set --tpu_topology with the Vertex AI SDK model upload API.
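As a rough illustration of these rules, the following sketch derives --tensor_parallel_size and --num_hosts from a v5e topology string. It assumes four chips per host, consistent with the 4x4 topology and --num_hosts=4 in the preceding example; the helper is illustrative and is not part of the Hex-LLM container.

# Illustrative helper: derive the multi-host serving arguments from a v5e
# topology string such as "4x4", assuming four chips per host as in the
# example above.
def multihost_args(tpu_topology: str, chips_per_host: int = 4) -> dict:
    total_chips = 1
    for dim in tpu_topology.split("x"):
        total_chips *= int(dim)
    if total_chips % chips_per_host:
        raise ValueError("Topology does not divide evenly across hosts.")
    return {
        "--tensor_parallel_size": total_chips,         # all cores in the slice
        "--num_hosts": total_chips // chips_per_host,  # host VMs in the slice
    }

print(multihost_args("4x4"))  # {'--tensor_parallel_size': 16, '--num_hosts': 4}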
Disaggregated serving [experimental]

Hex-LLM supports disaggregated serving as an experimental feature. This feature can only be enabled on a single-host setup, and its performance is still being tuned. Disaggregated serving is a method for balancing Time to First Token (TTFT) and Time Per Output Token (TPOT) for each request, and for balancing overall serving throughput. It separates the prefill and decode phases into different workloads so that they don't interfere with each other. This method is useful for scenarios with strict latency requirements. To enable this feature, set --disagg_topo in the Hex-LLM container arguments. The following example shows how to deploy the Hex-LLM container on TPU v5e-8 that serves the Llama 3.1 8B bfloat16 model:

hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Llama-3.1-8B",
"--data_parallel_size=1",
"--tensor_parallel_size=2",
"--disagg_topo=3,1",
"--hbm_utilization_factor=0.9",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
The --disagg_topo argument accepts a string in the format "number_of_prefill_workers,number_of_decode_workers". In the preceding example, the argument is set to "3,1" to configure three prefill workers and one decode worker. Each worker uses two TPU v5e cores.

Prefix caching

Prefix caching reduces Time to First Token (TTFT) for prompts that have identical content at the beginning of the prompt, such as company-wide preambles, common system instructions, and multi-turn conversation history. Instead of processing the same input tokens repeatedly, Hex-LLM can cache the computations for processed input tokens to improve TTFT. To enable this feature, set --enable_prefix_cache_hbm in the Hex-LLM container arguments. The following example shows how to deploy the Hex-LLM container on TPU v5e-8 that serves the Llama 3.1 8B bfloat16 model:

hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Llama-3.1-8B",
"--data_parallel_size=1",
"--tensor_parallel_size=4",
"--hbm_utilization_factor=0.9",
"--enable_prefix_cache_hbm",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
Hex-LLM uses prefix caching to optimize performance for prompts that exceed a certain length. The default length is 512 tokens and is configurable using prefill_len_padding. Cache hits occur in increments of this value, so the cached token count is always a multiple of prefill_len_padding. In the chat completion API response, the cached_tokens field of usage.prompt_tokens_details indicates how many of the prompt tokens were a cache hit.

"usage": {
"prompt_tokens": 643,
"total_tokens": 743,
"completion_tokens": 100,
"prompt_tokens_details": {
"cached_tokens": 512
}
}
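As a minimal sketch of how a client might read this field, the following snippet parses a chat completion response shaped like the one above and reports the cache-hit ratio. The response body here is a hard-coded stand-in; in practice you would read it from your own request to the server.

import json

# Stand-in chat completion response body shaped like the example above.
response_body = """
{"usage": {"prompt_tokens": 643, "total_tokens": 743, "completion_tokens": 100,
           "prompt_tokens_details": {"cached_tokens": 512}}}
"""

usage = json.loads(response_body)["usage"]
cached = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
print(f"{cached} of {usage['prompt_tokens']} prompt tokens were cache hits "
      f"({cached / usage['prompt_tokens']:.0%})")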
Chunked prefill

Chunked prefill splits a request prefill into smaller chunks and mixes prefill and decode operations into one batch step. Hex-LLM implements chunked prefill to balance the Time to First Token (TTFT) and Time per Output Token (TPOT) and to improve throughput. To enable this feature, set --enable_chunked_prefill in the Hex-LLM container arguments. The following example shows how to deploy the Hex-LLM container on TPU v5e-8 that serves the Llama 3.1 8B model:

hexllm_args = [
"--host=0.0.0.0",
"--port=7080",
"--model=meta-llama/Llama-3.1-8B",
"--data_parallel_size=1",
"--tensor_parallel_size=4",
"--hbm_utilization_factor=0.9",
"--enable_chunked_prefill",
]
model = aiplatform.Model.upload(
display_name=model_name,
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=["python", "-m", "hex_llm.server.api_server"],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=env_vars,
serving_container_shared_memory_size_mb=(16 * 1024), # 16 GB
serving_container_deployment_timeout=7200,
location=TPU_DEPLOYMENT_REGION,
)
model.deploy(
endpoint=endpoint,
machine_type=machine_type,
deploy_request_timeout=1800,
service_account=service_account,
min_replica_count=min_replica_count,
max_replica_count=max_replica_count,
)
4-bit quantization support

Quantization is a technique for reducing the computational and memory costs of running inference by representing the weights or activations with low-precision data types like INT8 or INT4 instead of the typical BF16 or FP32 data types. Hex-LLM supports INT8 weight-only quantization. It also supports models with INT4 weights that are quantized using AWQ zero-point quantization. Hex-LLM supports INT4 variants of the Mistral, Mixtral, and Llama model families. No additional flag is required to serve quantized models.
Get started in Model Garden

The Hex-LLM Cloud TPU serving container is integrated with Model Garden. You can use this serving technology through the playground, one-click deployment, and Colab Enterprise notebook examples for various models.
Playground: A pre-deployed endpoint for quick interaction with a model through a UI. Best for quickly testing model responses without any setup.
One-click deployment: Deploys a model to a new Vertex AI endpoint using default configurations from the model card. Best for easily creating a dedicated endpoint with minimal configuration.
Colab Enterprise notebook: Provides a notebook with code to deploy a model, allowing for full customization of deployment parameters. Best for customizing deployment settings, experimenting with different configurations, and integrating into a larger workflow.
Use the playground

The Model Garden playground provides a pre-deployed Vertex AI endpoint that you can access from the model card.
Use one-click deployment

You can deploy a custom Vertex AI endpoint with Hex-LLM using a model card.
Use the Colab Enterprise notebook

For flexibility and customization, you can use Colab Enterprise notebook examples to deploy a Vertex AI endpoint with Hex-LLM by using the Vertex AI SDK for Python. For example, the following code uploads the Hex-LLM container and deploys google/gemma-2-9b-it to a Vertex AI endpoint:
hexllm_args = [
"--model=google/gemma-2-9b-it",
"--tensor_parallel_size=4",
"--hbm_utilization_factor=0.8",
"--max_running_seqs=512",
]
hexllm_envs = {
"PJRT_DEVICE": "TPU",
"MODEL_ID": "google/gemma-2-9b-it",
"DEPLOY_SOURCE": "notebook",
}
model = aiplatform.Model.upload(
display_name="gemma-2-9b-it",
serving_container_image_uri=HEXLLM_DOCKER_URI,
serving_container_command=[
"python", "-m", "hex_llm.server.api_server"
],
serving_container_args=hexllm_args,
serving_container_ports=[7080],
serving_container_predict_route="/generate",
serving_container_health_route="/ping",
serving_container_environment_variables=hexllm_envs,
serving_container_shared_memory_size_mb=(16 * 1024),
serving_container_deployment_timeout=7200,
)
endpoint = aiplatform.Endpoint.create(display_name="gemma-2-9b-it-endpoint")
model.deploy(
endpoint=endpoint,
machine_type="ct5lp-hightpu-4t",
deploy_request_timeout=1800,
service_account="<your-service-account>",
min_replica_count=1,
max_replica_count=1,
)
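After deployment completes, you can send requests to the endpoint with the Vertex AI SDK. The following is a minimal sketch; the instance fields (prompt, max_tokens, temperature) are assumptions based on typical Model Garden text-generation containers rather than a schema documented on this page, so match them to the notebook example you deploy from.

# Minimal prediction sketch; the instance field names are illustrative assumptions.
instances = [
    {
        "prompt": "What is Cloud TPU?",  # assumed field name
        "max_tokens": 128,               # assumed field name
        "temperature": 0.7,              # assumed field name
    }
]
# Reuses the `endpoint` object created in the previous snippet.
response = endpoint.predict(instances=instances)
print(response.predictions)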
Configure server arguments and environment variables

You can set the following arguments to launch the Hex-LLM server and tailor them to your use case. The arguments are predefined for one-click deployment. To customize the arguments, you can use the notebook examples as a reference and set the arguments accordingly.
Model

--model: The model to load. You can specify a Hugging Face model ID, a Cloud Storage bucket path (gs://my-bucket/my-model), or a local path. The model artifacts are expected to follow the Hugging Face format and use safetensors files for the model weights. BitsAndBytes int8 and AWQ quantized model artifacts are supported for the Llama, Gemma 2, and Mistral/Mixtral model families. A sketch of loading model artifacts from a Cloud Storage bucket follows this list.

--tokenizer: The tokenizer to load. This can be a Hugging Face model ID, a Cloud Storage bucket path (gs://my-bucket/my-model), or a local path. If this argument is not set, it defaults to the value for --model.

--tokenizer_mode: The tokenizer mode. Possible choices are ["auto", "slow"]. The default value is "auto". If this is set to "auto", the fast tokenizer is used if available. The slow tokenizers are written in Python and provided in the Transformers library, while the fast tokenizers, which offer performance improvements, are written in Rust and provided in the Tokenizers library. For more information, see the Hugging Face documentation.

--trust_remote_code: Whether to allow remote code files defined in the Hugging Face model repositories. The default value is False.

--load_format: Format of model checkpoints to load. Possible choices are ["auto", "dummy"]. The default value is "auto". If this is set to "auto", the model weights are loaded in safetensors format. If this is set to "dummy", the model weights are randomly initialized, which is useful for experimentation.

--max_model_len: The maximum context length (input length plus the output length) to serve for the model. The default value is read from the model configuration file in Hugging Face format: config.json. A larger maximum context length requires more TPU memory.

--sliding_window: If set, this argument overrides the model's window size for sliding window attention. Setting this argument to a larger value causes the attention mechanism to include more tokens, which approaches the effect of standard self-attention. This argument is for experimental use only. For general use cases, use the model's original window size.

--seed: The seed for initializing all random number generators. Changing this argument might affect the generated output for the same prompt by changing the tokens that are sampled as next tokens. The default value is 0.
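The following sketch shows what pointing --model at a Cloud Storage copy of the model artifacts might look like. The bucket path is a hypothetical placeholder, and the remaining flags mirror the earlier single-host deployment examples.

# Hypothetical example: serve weights stored in Cloud Storage (Hugging Face
# layout with safetensors files). The bucket path is a placeholder.
hexllm_args = [
    "--host=0.0.0.0",
    "--port=7080",
    "--model=gs://my-bucket/my-model",      # placeholder Cloud Storage path
    "--tokenizer=gs://my-bucket/my-model",  # optional; defaults to --model
    "--tensor_parallel_size=4",
    "--hbm_utilization_factor=0.9",
]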
Inference engine

--num_hosts: The number of hosts to run. The default value is 1. For more information, see TPU v5e configuration.

--disagg_topo: Defines the number of prefill workers and decode workers for the experimental disaggregated serving feature. The default value is None. The argument follows the format: "number_of_prefill_workers,number_of_decode_workers".

--data_parallel_size: The number of data parallel replicas. The default value is 1. Increasing this from 1 to N approximately improves the throughput by a factor of N, while maintaining the same latency.

--tensor_parallel_size: The number of tensor parallel replicas. The default value is 1. Increasing the number of tensor parallel replicas generally improves latency, because it speeds up matrix multiplication by reducing the matrix size.

--worker_distributed_method: The distributed method to launch the worker. Use mp for the multiprocessing module or ray for the Ray library. The default value is mp.

--enable_jit: Whether to enable JIT (Just-in-Time Compilation) mode. The default value is True. Setting --no-enable_jit disables it. Enabling JIT mode improves inference performance at the cost of requiring additional time spent on initial compilation. In general, the inference performance benefits outweigh the overhead.

--warmup: Whether to warm up the server with sample requests during initialization. The default value is True. Setting --no-warmup disables it. Warmup is recommended because initial requests trigger compilation and are therefore slower.

--max_prefill_seqs: The maximum number of sequences that can be scheduled for prefilling per iteration. The default value is 1. The larger this value is, the higher the throughput the server can achieve, but with potential negative effects on latency.

--prefill_seqs_padding: The server pads the prefill batch size to a multiple of this value. The default value is 8. Increasing this value reduces model recompilation times, but increases wasted computation and inference overhead. The optimal setting depends on the request traffic.

--prefill_len_padding: The server pads the sequence length to a multiple of this value. The default value is 512. Increasing this value reduces model recompilation times, but increases wasted computation and inference overhead. The optimal setting depends on the data distribution of the requests.

--max_decode_seqs / --max_running_seqs: The maximum number of sequences that can be scheduled for decoding per iteration. The default value is 256. The larger this value is, the higher the throughput the server can achieve, but with potential negative effects on latency.

--decode_seqs_padding: The server pads the decode batch size to a multiple of this value. The default value is 8. Increasing this value reduces model recompilation times, but increases wasted computation and inference overhead. The optimal setting depends on the request traffic.

--decode_blocks_padding: The server pads the number of memory blocks used for a sequence's Key-Value cache (KV cache) to a multiple of this value during decoding. The default value is 128. Increasing this value reduces model recompilation times, but increases wasted computation and inference overhead. The optimal setting depends on the data distribution of the requests.

--enable_prefix_cache_hbm: Whether to enable prefix caching in HBM. The default value is False. Setting this argument can improve performance by reusing the computations of shared prefixes of prior requests.

--enable_chunked_prefill: Whether to enable chunked prefill. The default value is False. Setting this argument can support longer context lengths and improve performance.
Memory management

--hbm_utilization_factor: The percentage of free Cloud TPU High Bandwidth Memory (HBM) that can be allocated for the KV cache after model weights are loaded. The default value is 0.9. Setting this argument to a higher value increases the KV cache size and can improve throughput, but it increases the risk of running out of Cloud TPU HBM during initialization and at runtime.

--num_blocks: The number of device blocks to allocate for the KV cache. If this argument is set, the server ignores --hbm_utilization_factor. If this argument is not set, the server profiles HBM usage and computes the number of device blocks to allocate based on --hbm_utilization_factor. Setting this argument to a higher value increases the KV cache size and can improve throughput, but it increases the risk of running out of Cloud TPU HBM during initialization and at runtime.

--block_size: The number of tokens stored in a block. Possible choices are [8, 16, 32, 2048, 8192]. The default value is 32. Setting this argument to a larger value reduces overhead in block management, at the cost of more memory waste. The exact performance impact needs to be determined empirically.
Dynamic LoRA

--enable_lora: Whether to enable dynamic loading of LoRA adapters from Cloud Storage. The default value is False. This is supported for the Llama model family.

--max_lora_rank: The maximum LoRA rank supported for LoRA adapters defined in requests. The default value is 16. Setting this argument to a higher value allows for greater flexibility in the LoRA adapters that can be used with the server, but increases the amount of Cloud TPU HBM allocated for LoRA weights and decreases throughput.

--enable_lora_cache: Whether to enable caching of dynamic LoRA adapters. The default value is True. Setting --no-enable_lora_cache disables it. Caching improves performance because it removes the need to re-download previously used LoRA adapter files.

--max_num_mem_cached_lora: The maximum number of LoRA adapters stored in the TPU memory cache. The default value is 16. Setting this argument to a larger value improves the chance of a cache hit, but it increases the amount of Cloud TPU HBM usage.
You can also configure the server using the following environment variables (see the example after this list):

HEX_LLM_LOG_LEVEL: Controls the amount of logging information generated. The default value is INFO. Set this to one of the standard Python logging levels defined in the logging module.

HEX_LLM_VERBOSE_LOG: Whether to enable detailed logging output. Allowed values are true or false. The default value is false.
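For example, building on the earlier notebook snippet, these variables can be passed to the container through serving_container_environment_variables. The values below reuse names from the earlier example; the logging settings are illustrative overrides of the defaults listed above.

# Environment variables passed to the Hex-LLM container at model upload time.
# PJRT_DEVICE, MODEL_ID, and DEPLOY_SOURCE mirror the earlier notebook example;
# the logging settings are illustrative overrides of the documented defaults.
hexllm_envs = {
    "PJRT_DEVICE": "TPU",
    "MODEL_ID": "google/gemma-2-9b-it",
    "DEPLOY_SOURCE": "notebook",
    "HEX_LLM_LOG_LEVEL": "DEBUG",    # default is INFO
    "HEX_LLM_VERBOSE_LOG": "true",   # default is false
}
# Pass hexllm_envs as serving_container_environment_variables in aiplatform.Model.upload().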
Tune server arguments

The server arguments are interrelated and collectively affect serving performance. For example, a larger --max_model_len setting leads to higher TPU memory usage, which requires larger memory allocation and less batching. Some arguments are determined by the use case, while others can be tuned. This section provides a workflow for configuring the Hex-LLM server.
1. Estimate the TPU memory needed to hold the model weights: model_size * (num_bits / 8). For an 8B model and bfloat16 precision, the lower bound of TPU memory needed is 8 * (16 / 8) = 16 GB.
2. Estimate the number of TPU v5e chips needed, where each v5e chip offers 16 GB: tpu_memory / 16. For an 8B model and bfloat16 precision, you need more than one chip. Of the available 1-chip, 4-chip, and 8-chip configurations, the smallest configuration that offers more than one chip is the 4-chip configuration: ct5lp-hightpu-4t. You can then set --tensor_parallel_size=4.
3. Determine the maximum context length that your use case requires, for example --max_model_len=4096.
4. Tune the amount of free HBM reserved for the KV cache (--hbm_utilization_factor). Start with 0.95. Deploy the Hex-LLM server and test it with long prompts and high concurrency. If the server runs out of memory, reduce the utilization factor.

The following is a sample set of arguments for deploying Llama 3.1 8B Instruct:

python -m hex_llm.server.api_server \
--model=meta-llama/Llama-3.1-8B-Instruct \
--tensor_parallel_size=4 \
--max_model_len=4096 \
--hbm_utilization_factor=0.95
The following is a sample set of arguments for deploying Llama 3.1 70B Instruct AWQ on ct5lp-hightpu-4t:

python -m hex_llm.server.api_server \
:python -m hex_llm.server.api_server \
--model=hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 \
--tensor_parallel_size=4 \
--max_model_len=4096 \
--hbm_utilization_factor=0.45
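The sizing arithmetic from the preceding workflow can be sketched in a few lines of Python. This is only a back-of-the-envelope estimate for the weight memory and chip count described in steps 1 and 2 (16 GB of HBM per v5e chip); the KV cache and runtime overhead are handled separately by tuning --hbm_utilization_factor.

# Back-of-the-envelope sizing for the tuning workflow above (weights only).
def weight_memory_gb(params_billion: float, num_bits: int) -> float:
    # Step 1: model_size * (num_bits / 8)
    return params_billion * (num_bits / 8)

def min_v5e_chips(weight_gb: float, hbm_per_chip_gb: int = 16) -> int:
    # Step 2: smallest chip count whose total HBM strictly exceeds the weight
    # lower bound, leaving room for the KV cache.
    chips = 1
    while chips * hbm_per_chip_gb <= weight_gb:
        chips += 1
    return chips

weights = weight_memory_gb(8, 16)   # Llama 3.1 8B in bfloat16 -> 16.0 GB
needed = min_v5e_chips(weights)     # -> 2 chips at minimum
# Pick the smallest available single-host configuration (1, 4, or 8 chips).
config = next(c for c in (1, 4, 8) if c >= needed)  # -> 4 (ct5lp-hightpu-4t)
print(weights, needed, config)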
Request Cloud TPU quota

In Model Garden, your default quota is 32 Cloud TPU v5e chips in the us-west1 region. This quota applies to one-click deployments and Colab Enterprise notebook deployments. To request a higher quota, see Request a quota adjustment.