JetStream MaxText Inference on v5e Cloud TPU VM


JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs).

Before you begin

Follow the steps in Manage TPU resources to create a TPU VM with --accelerator-type set to v5litepod-8, then connect to the TPU VM.

Set up JetStream and MaxText

  1. Download the JetStream and MaxText GitHub repositories

       git clone -b jetstream-v0.2.0 https://github.com/google/maxtext.git
       git clone -b v0.2.0 https://github.com/google/JetStream.git
    
  2. Set up MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
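
The later steps assume the virtual environment is still active in the current shell. A minimal check, assuming the standard venv activate script (which exports VIRTUAL_ENV); the helper name venv_active is hypothetical:

```shell
# Hypothetical helper: succeeds only when a Python virtual environment is active.
# Relies on the VIRTUAL_ENV variable that the venv activate script exports.
venv_active() {
  [ -n "${VIRTUAL_ENV:-}" ]
}

if venv_active; then
  echo "virtual environment active: ${VIRTUAL_ENV}"
else
  echo "no virtual environment active; run: source .env/bin/activate"
fi
```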
    

Convert model checkpoints

You can run the JetStream MaxText server with Gemma or Llama2 models. This section describes how to convert checkpoints for various sizes of these models into a MaxText compatible format that the server can load.

Use a Gemma model checkpoint

  1. Download a Gemma checkpoint from Kaggle.
  2. Copy the checkpoint to your Cloud Storage bucket

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    For an example including values for ${YOUR_CKPT_PATH} and ${CHKPT_BUCKET}, see the conversion script.

  3. Convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
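
The copy and conversion commands above take CHKPT_BUCKET as a Cloud Storage location. As a minimal sketch, the following check rejects values without the gs:// scheme before any command runs. The helper name is_gcs_uri and the sample bucket name are hypothetical.

```shell
# Hypothetical helper: succeeds when the argument looks like a Cloud Storage URI
# (gs:// followed by at least one character). It does not check that the bucket exists.
is_gcs_uri() {
  case "$1" in
    gs://?*) return 0 ;;
    *)       return 1 ;;
  esac
}

# Sample value for illustration only; replace it with your own bucket.
CHKPT_BUCKET="gs://example-bucket/checkpoints"
if is_gcs_uri "${CHKPT_BUCKET}"; then
  echo "CHKPT_BUCKET looks valid: ${CHKPT_BUCKET}"
else
  echo "CHKPT_BUCKET must start with gs://" >&2
fi
```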
    

Use a Llama2 model checkpoint

  1. Download a Llama2 checkpoint from the open source community, or use one you have generated.

  2. Copy the checkpoints to your Cloud Storage bucket.

       gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    For an example including values for ${YOUR_CKPT_PATH} and ${CHKPT_BUCKET}, see the conversion script.

  3. Convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
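
Checkpoint conversion can run for a long time, so it can help to print the exact command and review its arguments first. A minimal sketch; the helper name conversion_cmd and the sample bucket are hypothetical, and the script path is the one used in the steps above:

```shell
# Hypothetical helper: prints (does not run) the conversion command for one model.
# Arguments: model family, model size, checkpoint bucket.
conversion_cmd() {
  printf 'bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh %s %s %s\n' \
    "$1" "$2" "$3"
}

# Sample bucket for illustration only; replace it with your own.
conversion_cmd llama2 13b "gs://example-bucket/llama2-13b"
```

Once the printed line looks right, run it from the maxtext directory.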
    

Run the JetStream MaxText server

This section describes how to run the MaxText server using a MaxText compatible checkpoint.

Configure environment variables for the MaxText server

Export the following environment variables based on the model you are using. Use the value for UNSCANNED_CKPT_PATH from the model_ckpt_conversion.sh output.

Create Gemma-7b environment variables for server flags

Configure the JetStream MaxText server flags.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Create Llama2-7b environment variables for server flags

Configure the JetStream MaxText server flags.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=6

Create Llama2-13b environment variables for server flags

Configure the JetStream MaxText server flags.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=2
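
The three variable blocks above differ only in MODEL_NAME, TOKENIZER_PATH, and PER_DEVICE_BATCH_SIZE. As a sketch, one function can select the right values by model name. The function name set_model_env is hypothetical; the values are the ones listed above.

```shell
# Hypothetical helper: exports the server variables for one of the three models.
# Usage: set_model_env gemma-7b | llama2-7b | llama2-13b
set_model_env() {
  case "$1" in
    gemma-7b)   TOKENIZER_PATH=assets/tokenizer.gemma;  PER_DEVICE_BATCH_SIZE=4 ;;
    llama2-7b)  TOKENIZER_PATH=assets/tokenizer.llama2; PER_DEVICE_BATCH_SIZE=6 ;;
    llama2-13b) TOKENIZER_PATH=assets/tokenizer.llama2; PER_DEVICE_BATCH_SIZE=2 ;;
    *) echo "unknown model: $1" >&2; return 1 ;;
  esac
  export TOKENIZER_PATH PER_DEVICE_BATCH_SIZE
  export MODEL_NAME="$1"
  # Values shared by all three models in this tutorial.
  export LOAD_PARAMETERS_PATH="${UNSCANNED_CKPT_PATH:-}"
  export MAX_PREFILL_PREDICT_LENGTH=1024
  export MAX_TARGET_LENGTH=2048
  export ICI_FSDP_PARALLELISM=1
  export ICI_AUTOREGRESSIVE_PARALLELISM=-1
  export ICI_TENSOR_PARALLELISM=1
  export SCAN_LAYERS=false
  export WEIGHT_DTYPE=bfloat16
}

set_model_env llama2-7b
echo "${MODEL_NAME}: batch ${PER_DEVICE_BATCH_SIZE}, tokenizer ${TOKENIZER_PATH}"
```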

Start the JetStream MaxText server

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
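
An unset variable expands to an empty string, so the command above would pass an argument such as model_name= with no value. A minimal preflight sketch that lists any missing variables before the server starts; the helper name missing_server_vars is hypothetical:

```shell
# Hypothetical helper: prints the names of required variables that are unset or empty.
# Prints nothing when every variable is set.
missing_server_vars() {
  for _name in TOKENIZER_PATH LOAD_PARAMETERS_PATH MAX_PREFILL_PREDICT_LENGTH \
      MAX_TARGET_LENGTH MODEL_NAME ICI_FSDP_PARALLELISM \
      ICI_AUTOREGRESSIVE_PARALLELISM ICI_TENSOR_PARALLELISM SCAN_LAYERS \
      WEIGHT_DTYPE PER_DEVICE_BATCH_SIZE; do
    # Look up the variable whose name is stored in _name.
    eval "_value=\${${_name}:-}"
    [ -n "${_value}" ] || echo "${_name}"
  done
}

_missing="$(missing_server_vars)"
if [ -n "${_missing}" ]; then
  echo "set these variables before starting the server:" >&2
  echo "${_missing}" >&2
fi
```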

JetStream MaxText Server flag descriptions

tokenizer_path
The path to a tokenizer (should match your model).
load_parameters_path
The directory to load the parameters from (no optimizer states).
per_device_batch_size
The decoding batch size per device (1 TPU chip = 1 device).
max_prefill_predict_length
The maximum length for the prefill when doing autoregression.
max_target_length
The maximum sequence length.
model_name
The model name.
ici_fsdp_parallelism
The number of shards for FSDP parallelism.
ici_autoregressive_parallelism
The number of shards for autoregressive parallelism.
ici_tensor_parallelism
The number of shards for tensor parallelism.
weight_dtype
The weight data type (for example, bfloat16).
scan_layers
A boolean flag for scanned layers. Set it to false here because the converted checkpoints are unscanned.
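
Because per_device_batch_size counts one TPU chip as one device, the total decoding batch is that value multiplied by the number of devices. A worked sketch, assuming all 8 chips of a v5litepod-8 are visible as devices; the device count is an assumption, and global_batch_size is a hypothetical helper:

```shell
# Hypothetical helper: total decoding batch = per-device batch size x device count.
global_batch_size() {
  echo $(( $1 * $2 ))
}

# Assumption: a v5litepod-8 exposes 8 devices.
NUM_DEVICES=8
for model_batch in gemma-7b:4 llama2-7b:6 llama2-13b:2; do
  model="${model_batch%%:*}"   # text before the colon
  batch="${model_batch##*:}"   # text after the colon
  echo "${model}: total decoding batch $(global_batch_size "${batch}" "${NUM_DEVICES}")"
done
```

With the per-device values from this tutorial, this gives 32 for gemma-7b, 48 for llama2-7b, and 16 for llama2-13b.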

Send a test request to the JetStream MaxText server

cd ~
python JetStream/jetstream/tools/requester.py

The output will be similar to the following:

Sending request to: dns:///[::1]:9000
Prompt: Today is a good day
Response:  to be a fan
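
A request sent before the server has finished loading the checkpoint can fail to connect. A minimal retry sketch, assuming the requester exits with a non-zero status when it cannot reach the server; the helper name retry, the attempt count, and the delay are hypothetical choices and not part of JetStream:

```shell
# Hypothetical helper: runs a command up to N times, sleeping between attempts.
# Usage: retry <attempts> <delay-seconds> <command> [args...]
retry() {
  _attempts="$1"; _delay="$2"; shift 2
  _try=1
  while ! "$@"; do
    if [ "${_try}" -ge "${_attempts}" ]; then
      echo "giving up after ${_try} attempts: $*" >&2
      return 1
    fi
    _try=$(( _try + 1 ))
    sleep "${_delay}"
  done
}

# Example (run from the home directory while the server starts in another shell):
#   retry 10 30 python JetStream/jetstream/tools/requester.py
```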

Run benchmarks with JetStream MaxText server

To get the best benchmark results, enable quantization for both weights and KV cache. Use AQT-trained or fine-tuned checkpoints to ensure accuracy. To enable quantization, set the quantization flags:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For the Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}
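
The quantized and unquantized launch commands differ only in the last two arguments. As a sketch, those arguments can be generated from the environment so that one command serves both cases. The helper name quantization_flags is hypothetical.

```shell
# Hypothetical helper: prints the quantization arguments for the server command.
# Each argument is emitted only when its variable is set and non-empty.
quantization_flags() {
  _flags=""
  [ -z "${QUANTIZATION:-}" ] || _flags="quantization=${QUANTIZATION}"
  [ -z "${QUANTIZE_KVCACHE:-}" ] || \
    _flags="${_flags:+${_flags} }quantize_kvcache=${QUANTIZE_KVCACHE}"
  echo "${_flags}"
}

QUANTIZATION=int8
QUANTIZE_KVCACHE=true
echo "extra server arguments: $(quantization_flags)"
```

To use it, append an unquoted $(quantization_flags) as the last argument of the server command so that the shell splits it into separate arguments.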

Benchmarking Gemma-7b

To benchmark Gemma-7b do the following:

  1. Download the ShareGPT dataset.
  2. Make sure to use the Gemma tokenizer (tokenizer.gemma) when running Gemma 7b.
  3. Add the --warmup-first flag for your first run to warm up the server.

# Activate the .env Python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer /home/$USER/maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true
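
The benchmark reads the dataset from the path given in --dataset-path, so it is worth confirming that the download produced a file before starting a long run. A minimal sketch; the helper name dataset_ready is hypothetical, and the check catches a missing or empty file, not a truncated one:

```shell
# Hypothetical helper: succeeds when the file exists and is not empty.
dataset_ready() {
  [ -s "$1" ]
}

DATASET="${HOME}/ShareGPT_V3_unfiltered_cleaned_split.json"
if dataset_ready "${DATASET}"; then
  echo "dataset found: ${DATASET}"
else
  echo "dataset missing or empty: ${DATASET}" >&2
fi
```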

Benchmarking larger Llama2 models

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true
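
The --request-rate value sets a lower bound on how long a benchmark run lasts. A worked sketch of the arithmetic, assuming requests are issued at roughly the configured average rate; min_send_seconds is a hypothetical helper:

```shell
# Hypothetical helper: minimum whole seconds needed to issue all prompts at a fixed rate.
# Arguments: number of prompts, requests per second (both integers). Rounds up.
min_send_seconds() {
  echo $(( ($1 + $2 - 1) / $2 ))
}

echo "1000 prompts at 5 requests/s take at least $(min_send_seconds 1000 5) s to send"
```

With the values used above, sending alone takes at least 200 seconds, before counting the time the server needs to finish the last responses.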

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

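# Delete the Cloud Storage buckets. This tutorial does not set these variables,
# so set each one to the bucket you want to delete before running these commands.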
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories (run from the directory where you cloned them).
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env
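
Deleting a bucket cannot be undone, so it can help to print the delete commands and review them first. A minimal dry-run sketch; the helper name print_bucket_deletes is hypothetical, and it only prints the gcloud commands shown above without running them:

```shell
# Hypothetical helper: for each named variable, prints the delete command when the
# variable is set, or a skip notice when it is not. It never deletes anything itself.
print_bucket_deletes() {
  for _name in "$@"; do
    # Look up the variable whose name is stored in _name.
    eval "_bucket=\${${_name}:-}"
    if [ -n "${_bucket}" ]; then
      echo "gcloud storage buckets delete ${_bucket}"
    else
      echo "# skipped ${_name}: not set"
    fi
  done
}

print_bucket_deletes MODEL_BUCKET BASE_OUTPUT_DIRECTORY DATASET_PATH
```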