JetStream MaxText-Inferenz auf v5e Cloud TPU-VM


JetStream ist eine durchsatz- und speicheroptimierte Engine für ein Large Language Model LLM-Inferenz auf XLA-Geräte (TPUs).

Hinweise

Führen Sie die Schritte unter TPU-Ressourcen verwalten aus, um Erstellen Sie eine TPU-VM-Einstellung --accelerator-type auf v5litepod-8 und stellen Sie eine Verbindung zu der TPU-VM.

JetStream und MaxText einrichten

  1. GitHub-Repository für JetStream und MaxText herunterladen

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. MaxText einrichten

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Modellprüfpunkte konvertieren

Sie können den JetStream MaxText Server mit Gemma- oder Llama2-Modellen ausführen. Dieses wird beschrieben, wie der JetStream MaxText-Server mit verschiedenen für diese Modelle.

Gemma-Modellprüfpunkt verwenden

  1. Laden Sie einen Gemma-Checkpoint von Kaggle herunter.
  2. Prüfpunkt in den Cloud Storage-Bucket kopieren

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    Ein Beispiel mit Werten für ${YOUR_CKPT_PATH} und ${CHKPT_BUCKET} finden Sie im Conversion-Skript.

  3. Konvertieren Sie den Gemma-Prüfpunkt in einen mit MaxText kompatiblen, nicht gescannten Prüfpunkt.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Llama2-Modellprüfpunkt verwenden

  1. Laden Sie einen Llama2-Checkpoint aus der Open-Source-Community herunter. oder verwenden Sie eins, das Sie erstellt haben.

  2. Kopieren Sie die Prüfpunkte in Ihren Cloud Storage-Bucket.

       gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    Ein Beispiel mit Werten für ${YOUR_CKPT_PATH} und ${CHKPT_BUCKET}: finden Sie im Conversion-Skript.

  3. Konvertieren Sie den Llama2-Prüfpunkt in einen mit MaxText kompatiblen, nicht gescannten Prüfpunkt.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

JetStream MaxText-Server ausführen

In diesem Abschnitt wird beschrieben, wie Sie den MaxText-Server mit einem MaxText-kompatiblen Prüfpunkt festlegen.

Umgebungsvariablen für den MaxText-Server konfigurieren

Exportieren Sie die folgenden Umgebungsvariablen basierend auf dem von Ihnen verwendeten Modell. Verwenden Sie den Wert für UNSCANNED_CKPT_PATH aus model_ckpt_conversion.sh .

Gemma-7b-Umgebungsvariablen für Server-Flags erstellen

Konfigurieren Sie die Server-Flags für JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Llama2-7b-Umgebungsvariablen für Server-Flags erstellen

Konfigurieren Sie die Server-Flags für JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Llama2-13b-Umgebungsvariablen für Server-Flags erstellen

Konfigurieren Sie die Server-Flags für JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

JetStream MaxText-Server starten

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Beschreibungen der JetStream MaxText Server-Flags

tokenizer_path
Der Pfad zu einem Tokenizer (sollte mit Ihrem Modell übereinstimmen).
load_parameters_path
Lädt die Parameter (keine Optimierungsstatus) aus einem bestimmten Verzeichnis
per_device_batch_size
Decodierungs-Batchgröße pro Gerät (1 TPU-Chip = 1 Gerät)
max_prefill_predict_length
Maximale Länge des Vorausfüllens bei automatischer Regression
max_target_length
Maximale Sequenzlänge
model_name
Modellname
ici_fsdp_parallelism
Die Anzahl der Shards für die FSDP-Parallelität
ici_autoregressive_parallelism
Die Anzahl der Shards für autoregressive Parallelität
ici_tensor_parallelism
Die Anzahl der Shards für die Tensor-Parallelität
weight_dtype
Datentyp „Gewichtung“ (z. B. bfloat16)
scan_layers
Boolesches Flag zum Scannen von Ebenen (für Inferenz auf „false“ gesetzt)

Testanfrage an den JetStream MaxText-Server senden

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

Die Ausgabe sollte in etwa so aussehen:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Benchmarks mit dem JetStream MaxText-Server ausführen

Die besten Benchmark-Ergebnisse erhalten Sie, wenn Sie die Quantisierung aktivieren (mit AQT-Training oder angepasste Checkpoints zur Gewährleistung der Genauigkeit) sowohl für Gewichtungen als auch für den KV-Cache. Zum Aktivieren Quantisierungs-Flags:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Benchmarking-Gemma-7b

Gehen Sie für das Benchmarking von Gemma-7b so vor:

  1. Laden Sie das ShareGPT-Dataset herunter.
  2. Achten Sie darauf, den Gemma-Tokenizer (tokenizer.gemma) zu verwenden, wenn Sie Gemma 7b ausführen.
  3. Fügen Sie das Flag --warmup-first für die erste Ausführung hinzu, um den Server aufzuwärmen.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Benchmarking für größeres Llama2

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Bereinigen

Damit Ihrem Google Cloud-Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, löschen Sie entweder das Projekt, das die Ressourcen enthält, oder Sie behalten das Projekt und löschen die einzelnen Ressourcen.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env