JetStream PyTorch Inference on v5e Cloud TPU VM

JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs).

Before you begin

Follow the steps in Set up the Cloud TPU environment to create a Google Cloud project, activate the TPU API, install the TPU CLI, and request TPU quota.

Follow the steps in Create a Cloud TPU using the CreateNode API to create a TPU VM, setting --accelerator-type to v5litepod-8.

Clone the JetStream repository and install dependencies

  1. Connect to your TPU VM using SSH

    • Set ${TPU_NAME} to your TPU's name.
    • Set ${PROJECT} to your Google Cloud project.
    • Set ${ZONE} to the Google Cloud zone in which to create your TPUs.

      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}

  2. Clone the JetStream repository

       git clone

    (Optional) Create a virtual Python environment using venv or conda and activate it.

  3. Run the installation script

       cd jetstream-pytorch
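
The connection step above can be sketched as a small script that fails early when a variable is missing. The values my-tpu, my-project, and us-west4-a are hypothetical placeholders, check_tpu_vars is an illustrative helper, and the script only prints the SSH command from step 1 instead of running it.

```shell
# Hypothetical placeholder values; replace them with your own.
TPU_NAME="my-tpu"
PROJECT="my-project"
ZONE="us-west4-a"

# Return non-zero and name the first variable that is empty.
check_tpu_vars() {
  if [ -z "$TPU_NAME" ]; then echo "TPU_NAME is not set" >&2; return 1; fi
  if [ -z "$PROJECT" ]; then echo "PROJECT is not set" >&2; return 1; fi
  if [ -z "$ZONE" ]; then echo "ZONE is not set" >&2; return 1; fi
  return 0
}

# Build the SSH command from step 1 and print it (dry run).
if check_tpu_vars; then
  SSH_CMD="gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}"
  echo "$SSH_CMD"
fi
```

To connect for real, run the printed command, or replace the echo with eval "$SSH_CMD".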

Download and convert weights

  1. Download the official Llama weights from GitHub.

  2. Convert the weights.

    • Set ${IN_CKPOINT} to the location that contains the Llama weights.
    • Set ${OUT_CKPOINT} to a location to write the converted checkpoints.

    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
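
Before starting the conversion, it can help to confirm that the input location exists and that the output location is writable. The following is a minimal sketch, not part of the JetStream tooling: prepare_ckpt_dirs is a hypothetical helper, the /tmp paths are placeholder defaults, and creating the output directory in advance is assumed to be harmless for the converter.

```shell
# Hypothetical placeholder locations; replace them with your own.
IN_CKPOINT="${IN_CKPOINT:-/tmp/llama-weights}"
OUT_CKPOINT="${OUT_CKPOINT:-/tmp/llama-converted}"

# Succeed only if the input directory ($1) exists; create the output directory ($2).
prepare_ckpt_dirs() {
  if [ ! -d "$1" ]; then
    echo "input checkpoint directory not found: $1" >&2
    return 1
  fi
  mkdir -p "$2"
}

# Export the variables used by the conversion command only when the check passes.
if prepare_ckpt_dirs "$IN_CKPOINT" "$OUT_CKPOINT"; then
  export input_ckpt_dir="$IN_CKPOINT"
  export output_ckpt_dir="$OUT_CKPOINT"
  echo "ready to convert: $input_ckpt_dir -> $output_ckpt_dir"
fi
```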

Run the JetStream PyTorch engine locally

To run the JetStream PyTorch engine locally, set the tokenizer path:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Run the JetStream PyTorch engine with Llama 7B

python --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Run the JetStream PyTorch engine with Llama 13B

python --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Run the JetStream server

python --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8

NOTE: The --platform=tpu= parameter must specify the number of TPU devices, which is 4 for v4-8 and 8 for v5litepod-8. For example, --platform=tpu=8.
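
The device counts in the note can be captured in a small lookup so that the flag is derived from the accelerator type rather than typed by hand. This is a sketch only: tpu_device_count is a hypothetical helper, and it covers just the two accelerator types named in this document.

```shell
# Map an accelerator type to the device count given in the note above.
# Only the types mentioned in this document are covered.
tpu_device_count() {
  case "$1" in
    v4-8) echo 4 ;;
    v5lite-8|v5litepod-8) echo 8 ;;
    *) echo "unknown accelerator type: $1" >&2; return 1 ;;
  esac
}

# Build the flag for the accelerator type used in this tutorial.
platform_flag="--platform=tpu=$(tpu_device_count v5litepod-8)"
echo "$platform_flag"
```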

After the server starts, the JetStream PyTorch engine is ready to receive gRPC calls.

Run benchmarks

Change to the deps/JetStream folder that was downloaded when you ran the installation script, then run the benchmark:

cd deps/JetStream
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/ --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs

For more information, see deps/JetStream/benchmarks/

Typical errors

If you get an Unexpected keyword argument 'device' error, try the following:

  • Uninstall the jax and jaxlib dependencies
  • Reinstall them from source

If you get an Out of memory error, try the following:

  • Use a smaller batch size
  • Use quantization
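
One way to pick a smaller batch size is to halve it step by step, in the same way that this tutorial uses 128 for Llama 7B and 64 for Llama 13B. The sketch below only prints candidate --batch_size values. next_batch_size is a hypothetical helper, and the lower bound of 16 is an arbitrary assumption, not a JetStream recommendation.

```shell
# Halve a batch size, never going below 1.
next_batch_size() {
  half=$(( $1 / 2 ))
  if [ "$half" -lt 1 ]; then half=1; fi
  echo "$half"
}

# Print candidate flags, starting from the Llama 7B value used above
# and stopping at an arbitrary lower bound of 16.
batch_size=128
while [ "$batch_size" -ge 16 ]; do
  echo "--batch_size=$batch_size"
  batch_size=$(next_batch_size "$batch_size")
done
```

Try each candidate in turn, keeping the other flags unchanged, until the engine starts without an Out of memory error.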

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Clean up the GitHub repositories

      # Clean up the JetStream repository
      rm -rf JetStream
      # Clean up the xla repository
      rm -rf xla
  2. Clean up the Python virtual environment

    rm -rf .env
  3. Delete your TPU resources

    For more information, see Delete your TPU resources.
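
As a rough sketch of the last step, the TPU VM can be deleted with the same gcloud command group that was used to connect to it. The values below are hypothetical placeholders for the ${TPU_NAME}, ${PROJECT}, and ${ZONE} values set earlier, and the script only prints the commands so that nothing is deleted by accident.

```shell
# Hypothetical placeholder values; use the same ones as in the SSH step.
TPU_NAME="my-tpu"
PROJECT="my-project"
ZONE="us-west4-a"

# Command that deletes the TPU VM.
DELETE_CMD="gcloud compute tpus tpu-vm delete ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}"
# Command that lists the remaining TPU VMs in the zone, to confirm the deletion.
LIST_CMD="gcloud compute tpus tpu-vm list --project ${PROJECT} --zone ${ZONE}"

# Dry run: print the commands instead of executing them.
echo "$DELETE_CMD"
echo "$LIST_CMD"
```

Run the printed delete command first, then the list command to check that the TPU VM no longer appears.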