在 v5e Cloud TPU 虚拟机上进行 JetStream PyTorch 推理


JetStream 是一款针对 XLA 设备 (TPU) 上的大语言模型 (LLM) 推理进行了吞吐量和内存优化的引擎。

准备工作

按照设置 Cloud TPU 环境中的步骤操作 创建 Google Cloud 项目,激活 TPU API,安装 TPU CLI, TPU 配额。

按照使用 CreateNode API 创建 Cloud TPU 中的步骤操作 创建一个 TPU 虚拟机,将 --accelerator-type 设置为 v5litepod-8

克隆 JetStream 代码库并安装依赖项

  1. 使用 SSH 连接到您的 TPU 虚拟机

    • 将 ${TPU_NAME} 设置为 TPU 的名称。
    • 将 ${PROJECT} 设置为您的 Google Cloud 项目
    • 将 ${ZONE} 设置为要在其中创建 TPU 的 Google Cloud 可用区
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. 克隆 JetStream 代码库

       git clone https://github.com/google/jetstream-pytorch.git
    

    (可选)使用 venvconda 创建虚拟 Python 环境并将其激活。

  3. 运行安装脚本

       cd jetstream-pytorch
       source install_everything.sh
    

下载并转换权重

  1. GitHub 下载官方 Llama 权重。

  2. 转换权重。

    • 将 ${IN_CKPOINT} 设置为包含 Llama 权重的位置
    • 将 ${OUT_CKPOINT} 设置为位置写入检查点
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

在本地运行 JetStream PyTorch 引擎

如需在本地运行 JetStream PyTorch 引擎,请设置分词器路径:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

使用 Llama 7B 运行 JetStream PyTorch 引擎

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

使用 Llama 13b 运行 JetStream PyTorch 引擎

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

运行 JetStream 服务器

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

注意:--platform=tpu= 参数需要指定 TPU 设备的数量(v4-8 为 4,v5lite-8 为 8)。例如,--platform=tpu=8

运行 run_server.py 后,JetStream PyTorch 引擎已准备好接收 gRPC 调用。

运行基准测试

切换到运行时下载的 deps/JetStream 文件夹 install_everything.sh

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

如需了解详情,请参阅 deps/JetStream/benchmarks/README.md

典型错误

如果您收到 Unexpected keyword argument 'device' 错误,请尝试以下操作:

  • 卸载 jaxjaxlib 依赖项
  • 使用 source install_everything.sh 重新安装

如果您收到 Out of memory 错误,请尝试以下操作:

  • 使用较小的批次大小
  • 使用量化

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  1. 清理 GitHub 代码库

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. 清理 Python 虚拟环境

    rm -rf .env
    
  3. 删除 TPU 资源

    如需了解详情,请参阅删除 TPU 资源