Hinweise
Folgen Sie der Anleitung unter TPU-Ressourcen verwalten, um eine TPU-VM-Einstellung von --accelerator-type
zu v5litepod-8
zu erstellen und eine Verbindung zur TPU-VM herzustellen.
JetStream und MaxText einrichten
JetStream und das GitHub-Repository von MaxText herunterladen
git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git git clone -b v0.2.2 https://github.com/google/JetStream.git
MaxText einrichten
# Create a python virtual environment sudo apt install python3.10-venv python -m venv .env source .env/bin/activate # Set up MaxText cd maxtext/ bash setup.sh
Modellprüfpunkte konvertieren
Sie können den JetStream MaxText-Server mit Gemma- oder Llama2-Modellen ausführen. In diesem Abschnitt wird beschrieben, wie Sie den JetStream MaxText-Server mit verschiedenen Größen dieser Modelle ausführen.
Gemma-Modell-Prüfpunkt verwenden
- Laden Sie einen Gemma-Prüfpunkt von Kaggle herunter.
Prüfpunkt in den Cloud Storage-Bucket kopieren
# Set YOUR_CKPT_PATH to the path to the checkpoints # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
Ein Beispiel mit Werten für
${YOUR_CKPT_PATH}
und${CHKPT_BUCKET}
finden Sie im Conversion-Script.Wandeln Sie den Gemma-Prüfpunkt in einen MaxText-kompatiblen nicht gescannten Prüfpunkt um.
# For gemma-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
Llama2-Modellprüfpunkt verwenden
Laden Sie einen Llama2-Checkpoint aus der Open-Source-Community herunter oder verwenden Sie einen von Ihnen generierten.
Kopieren Sie die Checkpoints in Ihren Cloud Storage-Bucket.
gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
Ein Beispiel mit Werten für
${YOUR_CKPT_PATH}
und${CHKPT_BUCKET}
finden Sie im Conversion-Script.Wandeln Sie den Llama2-Prüfpunkt in einen MaxText-kompatiblen nicht gescannten Prüfpunkt um.
# For llama2-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} # For llama2-13b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
JetStream MaxText-Server ausführen
In diesem Abschnitt wird beschrieben, wie Sie den MaxText-Server mit einem MaxText-kompatiblen Prüfpunkt ausführen.
Umgebungsvariablen für den MaxText-Server konfigurieren
Exportieren Sie die folgenden Umgebungsvariablen entsprechend dem verwendeten Modell.
Verwenden Sie den Wert für UNSCANNED_CKPT_PATH
aus der Ausgabe für model_ckpt_conversion.sh
.
Gemma-7b-Umgebungsvariablen für Serverflaggen erstellen
Konfigurieren Sie die Flags des JetStream MaxText-Servers.
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Llama2-7b-Umgebungsvariablen für Serverflaggen erstellen
Konfigurieren Sie die Flags des JetStream MaxText-Servers.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Llama2-13b-Umgebungsvariablen für Serverflaggen erstellen
Konfigurieren Sie die Flags des JetStream MaxText-Servers.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
JetStream MaxText-Server starten
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Beschreibungen der JetStream MaxText-Server-Flags
tokenizer_path
- Der Pfad zu einem Tokenisierer (muss mit Ihrem Modell übereinstimmen).
load_parameters_path
- Ladet die Parameter (keine Optimiererstatus) aus einem bestimmten Verzeichnis
per_device_batch_size
- Batchgröße für die Decodierung pro Gerät (1 TPU-Chip = 1 Gerät)
max_prefill_predict_length
- Maximale Länge für das Vorausfüllen bei der automatischen Regression
max_target_length
- Maximale Sequenzlänge
model_name
- Modellname
ici_fsdp_parallelism
- Anzahl der Shards für die Parallelität der vollständig fragmentierten Daten
ici_autoregressive_parallelism
- Die Anzahl der Shards für die autoregressive Parallelität.
ici_tensor_parallelism
- Die Anzahl der Shards für die Tensor-Parallelität.
weight_dtype
- Datentyp für die Gewichtung (z. B. bfloat16)
scan_layers
- Boolesches Flag für Ebenen scannen (für die Inferenz auf „false“ setzen)
Testanfrage an den JetStream MaxText-Server senden
cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2
Die Ausgabe sollte in etwa so aussehen:
Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response: to be a fan
Benchmarks mit dem JetStream MaxText-Server ausführen
Um die besten Benchmark-Ergebnisse zu erzielen, aktivieren Sie die Quantisierung (verwenden Sie AQT-geschulte oder fein abgestimmte Checkpoints, um für Genauigkeit zu sorgen) sowohl für Gewichte als auch für den KV-Cache. Wenn Sie die Quantisierung aktivieren möchten, legen Sie die Quantisierungsflags fest:
# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
Benchmarking Gemma-7b
So führen Sie einen Benchmark für Gemma-7b durch:
- Laden Sie das ShareGPT-Dataset herunter.
- Verwenden Sie beim Ausführen von Gemma 7b den Gemma-Tokenisierer (tokenizer.gemma).
- Fügen Sie für den ersten Durchlauf das Flag
--warmup-first
hinzu, um den Server zu wärmen.
# Activate the env python virtual environment
cd ~
source .env/bin/activate
# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
Benchmark für größere Llama2-Modelle
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
Bereinigen
Damit Ihrem Google Cloud-Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, löschen Sie entweder das Projekt, das die Ressourcen enthält, oder Sie behalten das Projekt und löschen die einzelnen Ressourcen.
# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}
# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream
# Clean up the python virtual environment
rm -rf .env