Inferência do JetStream MaxText na VM do Cloud TPU v5e


O JetStream é um mecanismo otimizado para capacidade de processamento e memória para modelos de linguagem grandes. (LLM) em dispositivos XLA (TPUs).

Antes de começar

Siga as etapas em Gerenciar recursos de TPU para criar uma configuração de VM de TPU --accelerator-type para v5litepod-8 e conectar-se a VM da TPU.

Configurar o JetStream e o MaxText

  1. Fazer o download do JetStream e do repositório MaxText no GitHub

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Configurar MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Converter pontos de verificação do modelo

É possível executar o servidor JetStream MaxText com modelos Gemma ou Llama2. Isso descreve como executar o servidor JetStream MaxText com vários tamanhos de esses modelos.

Usar um checkpoint do modelo Gemma

  1. Faça o download de um checkpoint Gemma do Kaggle.
  2. Copie o checkpoint para o bucket do Cloud Storage.

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para ver um exemplo que inclui valores para ${YOUR_CKPT_PATH} e ${CHKPT_BUCKET}, consulte o script de conversão.

  3. Converta o checkpoint Gemma em um checkpoint não verificado compatível com o MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Usar um checkpoint do modelo Llama2

  1. Faça o download de um checkpoint Llama2 na comunidade de código aberto. ou use uma que você gerou.

  2. Copie os pontos de verificação para o bucket do Cloud Storage.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para conferir um exemplo com valores de ${YOUR_CKPT_PATH} e ${CHKPT_BUCKET}, consulte o script de conversão.

  3. Converta o checkpoint Llama2 em um checkpoint não verificado compatível com o MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Executar o servidor JetStream MaxText

Esta seção descreve como executar o servidor MaxText usando um checkpoint compatível com MaxText.

Configurar variáveis de ambiente para o servidor MaxText

Exporte as seguintes variáveis de ambiente com base no modelo que você está usando. Use o valor de UNSCANNED_CKPT_PATH do model_ckpt_conversion.sh saída.

Criação de variáveis de ambiente Gemma-7b para flags do servidor

Configure as flags do servidor do JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Criação de variáveis de ambiente Llama2-7b para flags do servidor

Configure as flags do servidor do JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Criação de variáveis de ambiente Llama2-13b para flags do servidor

Configure as flags do servidor do JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Iniciar o servidor JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descrições de flags do servidor JetStream MaxText

tokenizer_path
O caminho para um tokenizador (precisa corresponder ao seu modelo).
load_parameters_path
Carrega os parâmetros (sem estados do otimizador) de um diretório específico
per_device_batch_size
tamanho do lote de decodificação por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
Tamanho máximo do preenchimento automático ao fazer regressão
max_target_length
Comprimento máximo da sequência
model_name
Nome do modelo
ici_fsdp_parallelism
O número de fragmentos para o paralelismo do FSDP
ici_autoregressive_parallelism
O número de fragmentos para paralelismo autoregressivo
ici_tensor_parallelism
O número de fragmentos para paralelismo de tensor
weight_dtype
Tipo de dados de peso (por exemplo, bfloat16)
scan_layers
Flag booleana de camadas (definida como "false" para inferência)
.

Enviar uma solicitação de teste para o servidor JetStream MaxText

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

A saída será semelhante a esta:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Executar comparativos com o servidor MaxText do JetStream

Para ter os melhores resultados de comparativo de mercado, ative a quantização (use o AQT treinado ou pontos de verificação ajustados para garantir a precisão) para pesos e cache de KV. Para ativar a quantização, defina as flags de quantização:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Como fazer comparações com o Gemma-7b

Para comparar o Gemma-7b, faça o seguinte:

  1. Faça o download do conjunto de dados do ShareGPT.
  2. Use o tokenizer Gemma (tokenizer.gemma) ao executar o Gemma 7b.
  3. Adicione a flag --warmup-first para a primeira execução para aquecer o servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Comparativo de mercado de Llama2 maior

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env