Inférence JetStream MaxText sur une VM Cloud TPU v5e


JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU).

Avant de commencer

Suivez la procédure décrite dans Gérer les ressources TPU pour créer une VM TPU en définissant --accelerator-type sur v5litepod-8, puis connectez-vous à la VM TPU.

Configurer JetStream et MaxText

  1. Télécharger JetStream et le dépôt GitHub de MaxText

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Configurer MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Convertir les points de contrôle du modèle

Vous pouvez exécuter le serveur JetStream MaxText avec les modèles Gemma ou Llama2. Cette section explique comment exécuter le serveur JetStream MaxText avec différentes tailles de ces modèles.

Utiliser un point de contrôle de modèle Gemma

  1. Téléchargez un point de contrôle Gemma depuis Kaggle.
  2. Copiez le point de contrôle dans votre bucket Cloud Storage.

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Pour obtenir un exemple incluant des valeurs pour ${YOUR_CKPT_PATH} et ${CHKPT_BUCKET}, consultez le script de conversion.

  3. Convertissez le point de contrôle Gemma en point de contrôle non analysé compatible avec MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Utiliser un point de contrôle du modèle Llama2

  1. Téléchargez un point de contrôle Llama2 auprès de la communauté Open Source ou utilisez-en un que vous avez généré.

  2. Copiez les points de contrôle dans votre bucket Cloud Storage.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Pour obtenir un exemple incluant des valeurs pour ${YOUR_CKPT_PATH} et ${CHKPT_BUCKET}, consultez le script de conversion.

  3. Convertissez le point de contrôle Llama2 en point de contrôle non analysé compatible avec MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Exécuter le serveur JetStream MaxText

Cette section explique comment exécuter le serveur MaxText à l'aide d'un point de contrôle compatible avec MaxText.

Configurer les variables d'environnement pour le serveur MaxText

Exportez les variables d'environnement suivantes en fonction du modèle que vous utilisez. Utilisez la valeur de UNSCANNED_CKPT_PATH à partir de la sortie model_ckpt_conversion.sh.

Créer des variables d'environnement Gemma-7b pour les indicateurs de serveur

Configurez les indicateurs du serveur JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Créer des variables d'environnement Llama2-7b pour les indicateurs de serveur

Configurez les indicateurs du serveur JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Créer des variables d'environnement Llama2-13b pour les indicateurs de serveur

Configurez les indicateurs du serveur JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Démarrer le serveur JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descriptions des indicateurs du serveur JetStream MaxText

tokenizer_path
Chemin d'accès à un tokenizer (doit correspondre à votre modèle).
load_parameters_path
Charge les paramètres (sans états d'optimiseur) à partir d'un répertoire spécifique.
per_device_batch_size
taille du lot de décodage par appareil (1 puce TPU = 1 appareil)
max_prefill_predict_length
Longueur maximale du préremplissage lors de la régression automatique
max_target_length
Longueur maximale de la séquence
model_name
Nom du modèle
ici_fsdp_parallelism
Nombre de segments pour le parallélisme FSDP
ici_autoregressive_parallelism
Nombre de segments pour le parallélisme autorégressif
ici_tensor_parallelism
Nombre de segments pour le parallélisme des Tensors
weight_dtype
Type de données de pondération (par exemple, bfloat16)
scan_layers
Indicateur booléen de numérisation des calques (défini sur "false" pour l'inférence)
.

Envoyer une requête de test au serveur JetStream MaxText

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

Le résultat doit ressembler à ce qui suit :

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Exécuter des benchmarks avec le serveur JetStream MaxText

Pour obtenir les meilleurs résultats de benchmark, activez la quantification (utilisez des points de contrôle entraînés ou affinés par AQT pour garantir la précision) à la fois pour les poids et le cache KV. Pour activer la quantification, définissez les indicateurs de quantification:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Benchmarking Gemma-7b

Pour comparer Gemma-7b, procédez comme suit:

  1. Téléchargez l'ensemble de données ShareGPT.
  2. Assurez-vous d'utiliser le tokenizer Gemma (tokenizer.gemma) lorsque vous exécutez Gemma 7b.
  3. Ajoutez l'option --warmup-first pour votre première exécution afin de préparer le serveur.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Benchmarking d'un Llama2 plus volumineux

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Effectuer un nettoyage

Pour éviter que les ressources utilisées lors de ce tutoriel soient facturées sur votre compte Google Cloud, supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env