Before you begin
Follow the steps in Manage TPU resources to create a TPU VM with --accelerator-type set to v5litepod-8, and connect to the TPU VM.
Set up JetStream and MaxText
Download the JetStream and MaxText GitHub repositories
git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
git clone -b v0.2.2 https://github.com/google/JetStream.git
Set up MaxText
# Create a Python virtual environment
sudo apt install python3.10-venv
python -m venv .env
source .env/bin/activate

# Set up MaxText
cd maxtext/
bash setup.sh
Convert model checkpoints
You can run the JetStream MaxText server with Gemma or Llama2 models. This section describes how to convert checkpoints for various sizes of these models into a format the server can load.
Use a Gemma model checkpoint
- Download a Gemma checkpoint from Kaggle.
- Copy the checkpoint to your Cloud Storage bucket.
# Set YOUR_CKPT_PATH to the path to the checkpoints
# Set CHKPT_BUCKET to the Cloud Storage bucket where you want to copy the checkpoints
gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
For an example that includes values for ${YOUR_CKPT_PATH} and ${CHKPT_BUCKET}, see the conversion script.
- Convert the Gemma checkpoint into a MaxText-compatible unscanned checkpoint.
# For gemma-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
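gcloud storage cp treats a destination without a gs:// prefix as a local path. A CHKPT_BUCKET value that is missing the scheme (for example my-bucket) would therefore likely copy the checkpoint into a local directory instead of Cloud Storage, and the conversion step would then fail. The following minimal guard assumes you want to stop early in that case. is_gcs_url is a hypothetical helper name and is not part of MaxText or JetStream.

```shell
# Hypothetical helper (not part of MaxText or JetStream): succeeds only when the
# argument looks like a Cloud Storage URL, i.e. "gs://" followed by at least one character.
is_gcs_url() {
  case "$1" in
    gs://?*) return 0 ;;
    *) return 1 ;;
  esac
}

# Example guard, commented out so the sketch runs without gcloud:
# is_gcs_url "${CHKPT_BUCKET}" || { echo "CHKPT_BUCKET must start with gs://" >&2; exit 1; }
# gcloud storage cp "${YOUR_CKPT_PATH}" "${CHKPT_BUCKET}" --recursive
```

The same check applies to the Llama2 copy and conversion commands.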
Use a Llama2 model checkpoint
- Download a Llama2 checkpoint from the open source community, or use one you have generated.
- Copy the checkpoints to your Cloud Storage bucket.
gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
For an example that includes values for ${YOUR_CKPT_PATH} and ${CHKPT_BUCKET}, see the conversion script.
- Convert the Llama2 checkpoint into a MaxText-compatible unscanned checkpoint.
# For llama2-7b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
# For llama2-13b
bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
Run the JetStream MaxText server
This section describes how to run the MaxText server using a MaxText-compatible checkpoint.
Configure environment variables for the MaxText server
Export the following environment variables based on the model you are using. For UNSCANNED_CKPT_PATH, use the value from the model_ckpt_conversion.sh output.
Create Gemma-7b environment variables for server flags
Configure the JetStream MaxText server flags.
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Create Llama2-7b environment variables for server flags
Configure the JetStream MaxText server flags.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Create Llama2-13b environment variables for server flags
Configure the JetStream MaxText server flags.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
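The three sets of exports above differ only in TOKENIZER_PATH, MODEL_NAME, and PER_DEVICE_BATCH_SIZE. If you switch between models often, you could wrap them in one function. The following sketch uses the hypothetical name set_model_env, copies the values listed above, and assumes UNSCANNED_CKPT_PATH is already set from the conversion output.

```shell
# Hypothetical helper that exports the server flags listed above for one model.
# Usage: set_model_env <gemma-7b|llama2-7b|llama2-13b>
# Assumes UNSCANNED_CKPT_PATH already holds the path printed by model_ckpt_conversion.sh.
set_model_env() {
  case "$1" in
    gemma-7b)
      export TOKENIZER_PATH=assets/tokenizer.gemma
      export PER_DEVICE_BATCH_SIZE=11 ;;
    llama2-7b)
      export TOKENIZER_PATH=assets/tokenizer.llama2
      export PER_DEVICE_BATCH_SIZE=11 ;;
    llama2-13b)
      export TOKENIZER_PATH=assets/tokenizer.llama2
      export PER_DEVICE_BATCH_SIZE=4 ;;
    *)
      echo "unsupported model: $1" >&2
      return 1 ;;
  esac
  # Flags shared by all three models in this tutorial.
  export MODEL_NAME="$1"
  export LOAD_PARAMETERS_PATH="${UNSCANNED_CKPT_PATH:-}"
  export MAX_PREFILL_PREDICT_LENGTH=1024
  export MAX_TARGET_LENGTH=2048
  export ICI_FSDP_PARALLELISM=1
  export ICI_AUTOREGRESSIVE_PARALLELISM=-1
  export ICI_TENSOR_PARALLELISM=1
  export SCAN_LAYERS=false
  export WEIGHT_DTYPE=bfloat16
}

# Example: set_model_env llama2-13b
```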
Start the JetStream MaxText server
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
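If any of these variables is empty, the shell expands its override to something like load_parameters_path= with no value. The server then starts with an unintended setting instead of failing immediately. The following preflight check is a hypothetical sketch, not a MaxText feature. It reports each missing variable and uses eval for POSIX-compatible indirection because the variable names are fixed strings.

```shell
# Hypothetical preflight check: report every server-flag variable that is unset or empty.
check_server_env() {
  missing=0
  for var in TOKENIZER_PATH LOAD_PARAMETERS_PATH MAX_PREFILL_PREDICT_LENGTH \
             MAX_TARGET_LENGTH MODEL_NAME ICI_FSDP_PARALLELISM \
             ICI_AUTOREGRESSIVE_PARALLELISM ICI_TENSOR_PARALLELISM \
             SCAN_LAYERS WEIGHT_DTYPE PER_DEVICE_BATCH_SIZE; do
    # Read the value of the variable whose name is stored in $var.
    eval "value=\${${var}:-}"
    if [ -z "${value}" ]; then
      echo "missing: ${var}" >&2
      missing=1
    fi
  done
  return "${missing}"
}

# Example: check_server_env && python MaxText/maxengine_server.py MaxText/configs/base.yml ...
```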
JetStream MaxText Server flag descriptions
tokenizer_path
- The path to a tokenizer. It must match your model.
load_parameters_path
- Loads the parameters (without optimizer states) from the specified directory.
per_device_batch_size
- Decoding batch size per device (1 TPU chip = 1 device).
max_prefill_predict_length
- Maximum length of the prefill (prompt) for autoregressive decoding.
max_target_length
- Maximum sequence length (prompt plus generated tokens).
model_name
- Model name, for example gemma-7b.
ici_fsdp_parallelism
- The number of shards for FSDP (fully sharded data parallel) parallelism.
ici_autoregressive_parallelism
- The number of shards for autoregressive parallelism.
ici_tensor_parallelism
- The number of shards for tensor parallelism.
weight_dtype
- Weight data type, for example bfloat16.
scan_layers
- Whether the model layers are scanned. Set this to `false` for inference with the unscanned checkpoints created by the conversion script.
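The tutorial sets ici_autoregressive_parallelism to -1. As we understand MaxText's config, -1 marks an ICI axis to be filled in automatically: that axis receives the device count divided by the product of the other ICI axes. The following arithmetic sketch assumes all ICI axes not shown here stay at 1 and that v5litepod-8 exposes 8 devices. It also treats the number of concurrent decode slots as per_device_batch_size times the device count, which is our reading of how MaxText sizes the decode batch, not a documented guarantee. resolve_ici and decode_slots are hypothetical names.

```shell
# Sketch: compute the value MaxText would (as we understand it) assign to a -1 ICI axis.
# $1 = total devices; remaining args = per-axis values, exactly one of which is -1.
resolve_ici() {
  total=$1
  shift
  product=1
  for v in "$@"; do
    if [ "$v" -ne -1 ]; then
      product=$((product * v))
    fi
  done
  echo $((total / product))
}

# Approximate decode slots: per-device batch size x number of devices.
decode_slots() {
  echo $(( $1 * $2 ))
}

resolve_ici 8 1 -1 1   # fsdp=1, autoregressive=-1, tensor=1 -> prints 8
decode_slots 11 8      # Gemma-7b / Llama2-7b -> prints 88
decode_slots 4 8       # Llama2-13b -> prints 32
```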
Send a test request to the JetStream MaxText server
cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2
The output will be similar to the following:
Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response: to be a fan
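The requester's tokenizer must match the one the server loaded. TOKENIZER_PATH is relative to ~/maxtext, and the requester runs from ~. If TOKENIZER_PATH is exported in the shell where you run the requester, the following hypothetical helper derives the matching path instead of requiring you to pick it by hand.

```shell
# Hypothetical helper: TOKENIZER_PATH (exported for the server) is relative to ~/maxtext,
# while requester.py is run from ~, so prefix it with "maxtext/".
requester_tokenizer() {
  if [ -z "${TOKENIZER_PATH:-}" ]; then
    echo "TOKENIZER_PATH is not set; export the server flags for your model first" >&2
    return 1
  fi
  echo "maxtext/${TOKENIZER_PATH}"
}

# Example, run from ~ while the server is up:
# python JetStream/jetstream/tools/requester.py --tokenizer "$(requester_tokenizer)"
```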
Run benchmarks with JetStream MaxText server
To get the best benchmark results, enable int8 quantization for both the weights and the KV cache. To preserve accuracy, use checkpoints that were trained or fine-tuned with AQT (Accurate Quantized Training). To enable quantization, set the quantization flags:
# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
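The quantized launch command differs from the earlier one only by the two trailing overrides. If you want one command for both the bfloat16 and int8 runs, you could emit those overrides only when the corresponding variables are set. The following sketch uses the hypothetical helper name quantization_flags and relies on the values containing no spaces, so unquoted word splitting is safe.

```shell
# Sketch: print the quantization overrides that are set, or nothing at all,
# so a single launch command covers both the unquantized and the int8 runs.
quantization_flags() {
  flags=""
  if [ -n "${QUANTIZATION:-}" ]; then
    flags="quantization=${QUANTIZATION}"
  fi
  if [ -n "${QUANTIZE_KVCACHE:-}" ]; then
    # Add a separating space only when flags already has content.
    flags="${flags:+${flags} }quantize_kvcache=${QUANTIZE_KVCACHE}"
  fi
  echo "${flags}"
}

# Example: append the output unquoted so the shell splits it into separate arguments.
# python MaxText/maxengine_server.py MaxText/configs/base.yml ... \
#   per_device_batch_size=${PER_DEVICE_BATCH_SIZE} $(quantization_flags)
```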
Benchmarking Gemma-7b
To benchmark Gemma-7b, do the following:
- Download the ShareGPT dataset.
- Use the Gemma tokenizer (tokenizer.gemma) when running Gemma-7b.
- Warm up the server on your first run. The command below does this with the --warmup-mode sampled flag.
# Activate the Python virtual environment
cd ~
source .env/bin/activate
# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the QPS by setting `--request-rate`; the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
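Running wget on every benchmark attempt re-downloads the dataset and, when the file already exists, saves the new copy under a suffixed name such as ShareGPT_V3_unfiltered_cleaned_split.json.1. The following sketch downloads the file only when it is missing or empty. It writes to a temporary .part file first so that an interrupted transfer is never mistaken for a complete one. fetch_sharegpt and SHAREGPT_URL are hypothetical names. The URL is the one used above.

```shell
# Hypothetical helper: download the ShareGPT file only when it is missing or empty.
SHAREGPT_URL="https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

fetch_sharegpt() {
  # $1 = destination file, for example ~/ShareGPT_V3_unfiltered_cleaned_split.json
  if [ -s "$1" ]; then
    echo "using existing $1"
    return 0
  fi
  # Download to a temporary name first so a partial file never looks complete.
  wget -O "$1.part" "${SHAREGPT_URL}" && mv "$1.part" "$1"
}

# Example: fetch_sharegpt ~/ShareGPT_V3_unfiltered_cleaned_split.json
```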
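Benchmarking Llama2 models
To benchmark Llama2-7b or Llama2-13b, use the same ShareGPT dataset with the Llama2 tokenizer (tokenizer.llama2).
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the QPS by setting `--request-rate`; the default value is inf.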
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
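Because --request-rate controls the load, you may want to compare several rates against the same server. The following sketch repeats the command above once per rate and uses only the flags shown there. sweep_request_rates, run_or_print, and the DRY_RUN switch are hypothetical names. When DRY_RUN=1 is set, the sketch prints the commands instead of running them.

```shell
# Hypothetical helper: run a command, or only print it when DRY_RUN=1.
run_or_print() {
  if [ "${DRY_RUN:-0}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Sketch: benchmark the running server once per request rate, run from ~.
# $1 = tokenizer path relative to ~; remaining args = request rates to try.
sweep_request_rates() {
  tokenizer=$1
  shift
  for rate in "$@"; do
    run_or_print python JetStream/benchmarks/benchmark_serving.py \
      --tokenizer "${tokenizer}" \
      --num-prompts 1000 \
      --dataset sharegpt \
      --dataset-path "${HOME}/ShareGPT_V3_unfiltered_cleaned_split.json" \
      --max-output-length 1024 \
      --request-rate "${rate}" \
      --warmup-mode sampled
  done
}

# Example: DRY_RUN=1 sweep_request_rates maxtext/assets/tokenizer.llama2 1 2 5
```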
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
# Delete the Cloud Storage buckets and their contents
gcloud storage rm --recursive ${MODEL_BUCKET}
gcloud storage rm --recursive ${BASE_OUTPUT_DIRECTORY}
gcloud storage rm --recursive ${DATASET_PATH}
# Clean up the MaxText and JetStream repositories (cloned in your home directory)
cd ~
rm -rf maxtext
rm -rf JetStream
# Deactivate and remove the Python virtual environment
deactivate
rm -rf .env
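If one of the bucket variables is unset in your current shell, the deletion command above receives no target. The following hypothetical guard skips empty or non-gs:// values and supports a DRY_RUN=1 preview. It uses gcloud storage rm --recursive, which removes a bucket's objects and then the bucket itself. By contrast, gcloud storage buckets delete works only on empty buckets. delete_buckets and DRY_RUN are illustrative names.

```shell
# Hypothetical cleanup helper: delete each gs:// target (bucket or prefix) recursively,
# skipping empty or malformed values. Set DRY_RUN=1 to preview without deleting.
delete_buckets() {
  for bucket in "$@"; do
    case "${bucket}" in
      gs://?*)
        if [ "${DRY_RUN:-0}" = "1" ]; then
          echo "would delete ${bucket}"
        else
          gcloud storage rm --recursive "${bucket}"
        fi ;;
      *)
        echo "skipping non-gs:// value: '${bucket}'" >&2 ;;
    esac
  done
}

# Example: DRY_RUN=1 delete_buckets "${MODEL_BUCKET}" "${BASE_OUTPUT_DIRECTORY}" "${DATASET_PATH}"
```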