Inferenza JetStream MaxText su VM Cloud TPU v5e


JetStream è un motore ottimizzato per la velocità effettiva e la memoria per l'inferenza dei modelli linguistici di grandi dimensioni (LLM) sui dispositivi XLA (TPU).

Prima di iniziare

Segui la procedura descritta in Gestire le risorse TPU per creare una VM TPU impostando --accelerator-type su v5litepod-8 e connettiti alla VM TPU.

Configurare JetStream e MaxText

  1. Scarica JetStream e il repository GitHub di MaxText

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Configurare MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Converti i checkpoint del modello

Puoi eseguire JetStream MaxText Server con i modelli Gemma o Llama2. Questa sezione descrive come eseguire il server JetStream MaxText con vari modelli di dimensioni diverse.

Utilizzare un punto di controllo del modello Gemma

  1. Scarica un controllo Gemma da Kaggle.
  2. Copia il checkpoint nel tuo bucket Cloud Storage

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Per un esempio che include i valori di ${YOUR_CKPT_PATH} e ${CHKPT_BUCKET}, consulta lo script di conversione.

  3. Converti il checkpoint Gemma in un checkpoint non sottoposto a scansione compatibile con MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Utilizzare un checkpoint del modello Llama2

  1. Scarica un checkpoint Llama2 dalla community open source o utilizzane uno generato da te.

  2. Copia i checkpoint nel bucket Cloud Storage.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Per un esempio che include i valori di ${YOUR_CKPT_PATH} e ${CHKPT_BUCKET}, consulta lo script di conversione.

  3. Converti il checkpoint Llama2 in un checkpoint non sottoposto a scansione compatibile con MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Esegui il server JetStream MaxText

Questa sezione descrive come eseguire il server MaxText utilizzando un checkpoint compatibile con MaxText.

Configurare le variabili di ambiente per il server MaxText

Esporta le seguenti variabili di ambiente in base al modello in uso. Utilizza il valore di UNSCANNED_CKPT_PATH dall'output di model_ckpt_conversion.sh.

Creare le variabili di ambiente Gemma-7b per i flag del server

Configura i flag del server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Creare variabili di ambiente Llama2-7b per i flag del server

Configura i flag del server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Creare variabili di ambiente Llama2-13b per i flag del server

Configura i flag del server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Avvia il server JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descrizioni dei flag del server JetStream MaxText

tokenizer_path
Il percorso di un tokenizzatore (deve corrispondere al tuo modello).
load_parameters_path
Carica i parametri (senza stati dell'ottimizzatore) da una directory specifica
per_device_batch_size
Dimensione del batch di decodifica per dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
Lunghezza massima per il precompletamento durante l'autoregressione
max_target_length
Lunghezza massima della sequenza
model_name
Nome modello
ici_fsdp_parallelism
Il numero di shard per il parallelismo FSDP
ici_autoregressive_parallelism
Numero di shard per il parallelismo autoregressivo
ici_tensor_parallelism
Il numero di shard per il parallelismo tensoriale
weight_dtype
Tipo di dati del peso (ad es. bfloat16)
scan_layers
Flag booleano per l'analisi dei livelli (impostato su "false" per l'inferenza)

Invia una richiesta di test al server JetStream MaxText

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

L'output sarà simile al seguente:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Eseguire benchmark con il server JetStream MaxText

Per ottenere i migliori risultati del benchmark, abilita la quantizzazione (utilizza i checkpoint addestrati o ottimizzati con AQT per garantire l'accuratezza) sia per i pesi che per la cache KV. Per attivare la quantizzazione, imposta i flag di quantizzazione:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Benchmarking di Gemma-7b

Per eseguire il benchmark di Gemma-7b:

  1. Scarica il set di dati ShareGPT.
  2. Assicurati di utilizzare il tokenizzatore Gemma (tokenizer.gemma) quando esegui Gemma 7b.
  3. Aggiungi il flag --warmup-first per la prima esecuzione per eseguire l'inizializzazione del server.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Benchmark di Llama2 di dimensioni maggiori

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Esegui la pulizia

Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env