在 v5e Cloud TPU 虚拟机上进行 JetStream PyTorch 推断


JetStream 是用于在 XLA 设备 (TPU) 上的大型语言模型 (LLM) 推断的吞吐量和内存优化引擎。

准备工作

按照设置 Cloud TPU 环境中的步骤创建一个 Google Cloud 项目、激活 TPU API、安装 TPU CLI 并申请 TPU 配额。

按照使用 CreateNode API 创建 Cloud TPU 中的步骤创建一个 TPU 虚拟机,并将 --accelerator-type 设置为 v5litepod-8

克隆 JetStream 代码库并安装依赖项

  1. 使用 SSH 连接到您的 TPU 虚拟机

    • 将 ${TPU_NAME} 设置为您的 TPU 名称。
    • 将 ${PROJECT} 设为您的 Google Cloud 项目
    • 将 ${ZONE} 设置为要在其中创建 TPU 的 Google Cloud 可用区
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. 克隆 JetStream 代码库

       git clone https://github.com/google/jetstream-pytorch.git
    

    (可选)使用 venvconda 创建一个虚拟 Python 环境并将其激活。

  3. 运行安装脚本

       cd jetstream-pytorch
       source install_everything.sh
    

下载并转换权重

  1. GitHub 下载官方 Llama 权重。

  2. 转换权重。

    • 将 ${IN_CKPOINT} 设置为包含 Llama 权重的位置
    • 将 ${OUT_CKPOINT} 设置为位置写入检查点
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

在本地运行 JetStream PyTorch 引擎

如需在本地运行 JetStream PyTorch 引擎,请设置标记生成器路径:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

使用 Llama 7B 运行 JetStream PyTorch 引擎

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

使用 Llama 13b 运行 JetStream PyTorch 引擎

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

运行 JetStream 服务器

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

注意:--platform=tpu= 参数需要指定 TPU 设备的数量(v4-8 为 4,v5lite-8 为 8)。例如 --platform=tpu=8

运行 run_server.py 后,JetStream PyTorch 引擎即可接收 gRPC 调用。

运行基准测试

切换到您运行 install_everything.sh 时下载的 deps/JetStream 文件夹。

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

如需了解详情,请参阅 deps/JetStream/benchmarks/README.md

典型错误

如果您收到 Unexpected keyword argument 'device' 错误,请尝试以下操作:

  • 卸载 jaxjaxlib 依赖项
  • 使用 source install_everything.sh 重新安装

如果您收到 Out of memory 错误,请尝试以下操作:

  • 使用较小的批次大小
  • 使用量化

清理

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  1. 清理 GitHub 代码库

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. 清理 Python 虚拟环境

    rm -rf .env
    
  3. 删除 TPU 资源

    如需了解详情,请参阅删除 TPU 资源