准备工作
按照管理 TPU 资源中的步骤创建 TPU 虚拟机,将 --accelerator-type
设置为 v5litepod-8
,然后连接到 TPU 虚拟机。
设置 JetStream 和 MaxText
下载 JetStream 和 MaxText GitHub 代码库
git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git git clone -b v0.2.2 https://github.com/google/JetStream.git
设置 MaxText
# Create a python virtual environment sudo apt install python3.10-venv python -m venv .env source .env/bin/activate # Set up MaxText cd maxtext/ bash setup.sh
转换模型检查点
您可以使用 Gemma 或 Llama2 模型运行 JetStream MaxText 服务器。本部分介绍了如何使用这些模型的不同大小运行 JetStream MaxText 服务器。
使用 Gemma 模型检查点
- 从 Kaggle 下载 Gemma 检查点。
将检查点复制到 Cloud Storage 存储桶
# Set YOUR_CKPT_PATH to the path to the checkpoints # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
如需查看包含
${YOUR_CKPT_PATH}
和${CHKPT_BUCKET}
值的示例,请参阅转化脚本。将 Gemma 检查点转换为与 MaxText 兼容的未扫描检查点。
# For gemma-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
使用 Llama2 模型检查点
从开源社区下载 Llama2 检查点,或使用您自己生成的检查点。
将检查点复制到您的 Cloud Storage 存储桶。
gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
如需查看包含
${YOUR_CKPT_PATH}
和${CHKPT_BUCKET}
值的示例,请参阅转化脚本。将 Llama2 检查点转换为与 MaxText 兼容的未扫描检查点。
# For llama2-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} # For llama2-13b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
运行 JetStream MaxText 服务器
本部分介绍了如何使用与 MaxText 兼容的检查点运行 MaxText 服务器。
为 MaxText 服务器配置环境变量
根据您使用的模型导出以下环境变量。
使用 model_ckpt_conversion.sh
输出中的 UNSCANNED_CKPT_PATH
值。
为服务器标志创建 Gemma-7b 环境变量
配置 JetStream MaxText 服务器标志。
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
为服务器标志创建 Llama2-7b 环境变量
配置 JetStream MaxText 服务器标志。
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
为服务器标志创建 Llama2-13b 环境变量
配置 JetStream MaxText 服务器标志。
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
启动 JetStream MaxText 服务器
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
JetStream MaxText 服务器标志说明
tokenizer_path
- 分词器的路径(应与您的模型一致)。
load_parameters_path
- 从特定目录加载参数(无优化器状态)
per_device_batch_size
- 每个设备的解码批次大小(1 个 TPU 芯片 = 1 个设备)
max_prefill_predict_length
- 进行自动回归时预填充的最大长度
max_target_length
- 序列长度上限
model_name
- 模型名称
ici_fsdp_parallelism
- 用于 FSDP 并行的分片数
ici_autoregressive_parallelism
- 用于自动回归并行的分片数
ici_tensor_parallelism
- 用于张量并行的分片数
weight_dtype
- 权重数据类型(例如 bfloat16)
scan_layers
- Scan layers 布尔标志(设置为 `false` 以进行推理)
向 JetStream MaxText 服务器发送测试请求
cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2
输出将如下所示:
Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response: to be a fan
使用 JetStream MaxText 服务器运行基准测试
为了获得最佳基准测试结果,请为权重和 KV 缓存启用量化(使用 AQT 训练或微调的检查点以确保准确性)。如需启用量化,请设置量化标志:
# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
对 Gemma-7b 进行基准测试
如需对 Gemma-7b 进行基准测试,请执行以下操作:
- 下载 ShareGPT 数据集。
- 运行 Gemma 7b 时,请务必使用 Gemma 分词器 (tokenizer.gemma)。
- 为首次运行添加
--warmup-first
标志,以预热服务器。
# Activate the env python virtual environment
cd ~
source .env/bin/activate
# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
对较大的 Llama2 进行基准测试
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
清理
为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。
# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}
# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream
# Clean up the python virtual environment
rm -rf .env