Antes de começar
Siga as etapas em Gerenciar recursos de TPU para
criar uma VM de TPU definindo --accelerator-type
como v5litepod-8
e se conectando à
VM de TPU.
Configurar o JetStream e o MaxText
Fazer o download do JetStream e do repositório MaxText no GitHub
git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git git clone -b v0.2.2 https://github.com/google/JetStream.git
Configurar o MaxText
# Create a python virtual environment sudo apt install python3.10-venv python -m venv .env source .env/bin/activate # Set up MaxText cd maxtext/ bash setup.sh
Converter checkpoints de modelos
É possível executar o servidor JetStream MaxText com modelos Gemma ou Llama2. Esta seção descreve como executar o servidor JetStream MaxText com vários tamanhos desses modelos.
Usar um checkpoint do modelo Gemma
- Faça o download de um ponto de verificação do Gemma no Kaggle.
Copiar o checkpoint para o bucket do Cloud Storage
# Set YOUR_CKPT_PATH to the path to the checkpoints # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
Para conferir um exemplo com valores de
${YOUR_CKPT_PATH}
e${CHKPT_BUCKET}
, consulte o script de conversão.Converta o checkpoint do Gemma em um checkpoint não verificado compatível com MaxText.
# For gemma-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
Usar um checkpoint do modelo Llama2
Faça o download de um ponto de verificação do Llama2 na comunidade de código aberto ou use um que você gerou.
Copie os pontos de verificação para o bucket do Cloud Storage.
gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
Para conferir um exemplo com valores de
${YOUR_CKPT_PATH}
e${CHKPT_BUCKET}
, consulte o script de conversão.Converta o ponto de verificação Llama2 em um ponto de verificação não verificado compatível com MaxText.
# For llama2-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} # For llama2-13b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
Executar o servidor JetStream MaxText
Esta seção descreve como executar o servidor MaxText usando um checkpoint compatível com MaxText.
Configurar as variáveis de ambiente para o servidor MaxText
Exporte as seguintes variáveis de ambiente com base no modelo que você está usando.
Use o valor de UNSCANNED_CKPT_PATH
da saída
model_ckpt_conversion.sh
.
Criação de variáveis de ambiente Gemma-7b para flags do servidor
Configure as flags do servidor MaxText do JetStream.
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Criação de variáveis de ambiente Llama2-7b para flags do servidor
Configure as flags do servidor MaxText do JetStream.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Criação de variáveis de ambiente Llama2-13b para flags do servidor
Configure as flags do servidor MaxText do JetStream.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
Iniciar o servidor JetStream MaxText
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Descrições de flags do servidor do JetStream MaxText
tokenizer_path
- O caminho para um tokenizer (precisa corresponder ao seu modelo).
load_parameters_path
- Carrega os parâmetros (sem estados do otimizador) de um diretório específico
per_device_batch_size
- tamanho do lote de decodificação por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
- Comprimento máximo do preenchimento prévio ao fazer a regressão automática
max_target_length
- Comprimento máximo da sequência
model_name
- Nome do modelo
ici_fsdp_parallelism
- O número de fragmentos para paralelismo de FSDP
ici_autoregressive_parallelism
- O número de fragmentos para paralelismo autoregressivo
ici_tensor_parallelism
- O número de fragmentos para paralelismo de tensor
weight_dtype
- Tipo de dados de peso (por exemplo, bfloat16)
scan_layers
- Flag booleana de camadas de verificação (definida como "false" para inferência)
Enviar uma solicitação de teste para o servidor JetStream MaxText
cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2
A saída será semelhante a esta:
Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response: to be a fan
Executar comparativos com o servidor JetStream MaxText
Para ter os melhores resultados de comparação, ative a quantização (use pontos de verificação treinados ou ajustados finamente com AQT para garantir a precisão) para pesos e cache KV. Para ativar a quantização, defina as flags de quantização:
# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
Comparação de mercado do Gemma-7b
Para comparar o Gemma-7b, faça o seguinte:
- Faça o download do conjunto de dados ShareGPT.
- Use o tokenizer Gemma (tokenizer.gemma) ao executar o Gemma 7b.
- Adicione a flag
--warmup-first
para a primeira execução para aquecer o servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate
# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
Comparativo de mercado do Llama2 maior
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
Limpar
Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.
# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}
# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream
# Clean up the python virtual environment
rm -rf .env