Inferência do JetStream MaxText na VM do Cloud TPU v5e


O JetStream é um mecanismo otimizado para capacidade de processamento e memória para inferência de modelos de linguagem grandes (LLMs) em dispositivos XLA (TPUs).

Antes de começar

Siga as etapas em Gerenciar recursos de TPU para criar uma VM de TPU definindo --accelerator-type como v5litepod-8 e se conectando à VM de TPU.

Configurar o JetStream e o MaxText

  1. Fazer o download do JetStream e do repositório MaxText no GitHub

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Configurar o MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Converter checkpoints de modelos

É possível executar o servidor JetStream MaxText com modelos Gemma ou Llama2. Esta seção descreve como executar o servidor JetStream MaxText com vários tamanhos desses modelos.

Usar um checkpoint do modelo Gemma

  1. Faça o download de um ponto de verificação do Gemma no Kaggle.
  2. Copiar o checkpoint para o bucket do Cloud Storage

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para conferir um exemplo com valores de ${YOUR_CKPT_PATH} e ${CHKPT_BUCKET}, consulte o script de conversão.

  3. Converta o checkpoint do Gemma em um checkpoint não verificado compatível com MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Usar um checkpoint do modelo Llama2

  1. Faça o download de um ponto de verificação do Llama2 na comunidade de código aberto ou use um que você gerou.

  2. Copie os pontos de verificação para o bucket do Cloud Storage.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para conferir um exemplo com valores de ${YOUR_CKPT_PATH} e ${CHKPT_BUCKET}, consulte o script de conversão.

  3. Converta o ponto de verificação Llama2 em um ponto de verificação não verificado compatível com MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Executar o servidor JetStream MaxText

Esta seção descreve como executar o servidor MaxText usando um checkpoint compatível com MaxText.

Configurar as variáveis de ambiente para o servidor MaxText

Exporte as seguintes variáveis de ambiente com base no modelo que você está usando. Use o valor de UNSCANNED_CKPT_PATH da saída model_ckpt_conversion.sh.

Criação de variáveis de ambiente Gemma-7b para flags do servidor

Configure as flags do servidor MaxText do JetStream.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Criação de variáveis de ambiente Llama2-7b para flags do servidor

Configure as flags do servidor MaxText do JetStream.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Criação de variáveis de ambiente Llama2-13b para flags do servidor

Configure as flags do servidor MaxText do JetStream.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Iniciar o servidor JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descrições de flags do servidor do JetStream MaxText

tokenizer_path
O caminho para um tokenizer (precisa corresponder ao seu modelo).
load_parameters_path
Carrega os parâmetros (sem estados do otimizador) de um diretório específico
per_device_batch_size
tamanho do lote de decodificação por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
Comprimento máximo do preenchimento prévio ao fazer a regressão automática
max_target_length
Comprimento máximo da sequência
model_name
Nome do modelo
ici_fsdp_parallelism
O número de fragmentos para paralelismo de FSDP
ici_autoregressive_parallelism
O número de fragmentos para paralelismo autoregressivo
ici_tensor_parallelism
O número de fragmentos para paralelismo de tensor
weight_dtype
Tipo de dados de peso (por exemplo, bfloat16)
scan_layers
Flag booleana de camadas de verificação (definida como "false" para inferência)

Enviar uma solicitação de teste para o servidor JetStream MaxText

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

A saída será semelhante a esta:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Executar comparativos com o servidor JetStream MaxText

Para ter os melhores resultados de comparação, ative a quantização (use pontos de verificação treinados ou ajustados finamente com AQT para garantir a precisão) para pesos e cache KV. Para ativar a quantização, defina as flags de quantização:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Comparação de mercado do Gemma-7b

Para comparar o Gemma-7b, faça o seguinte:

  1. Faça o download do conjunto de dados ShareGPT.
  2. Use o tokenizer Gemma (tokenizer.gemma) ao executar o Gemma 7b.
  3. Adicione a flag --warmup-first para a primeira execução para aquecer o servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Comparativo de mercado do Llama2 maior

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env