Inferensi MaxText JetStream pada VM Cloud TPU v5e


JetStream adalah mesin throughput dan memori yang dioptimalkan untuk inferensi model bahasa besar (LLM) pada perangkat XLA (TPU).

Sebelum memulai

Ikuti langkah-langkah di Mengelola resource TPU untuk membuat setelan VM TPU --accelerator-type ke v5litepod-8, dan terhubung ke VM TPU.

Menyiapkan JetStream dan MaxText

  1. Download repositori JetStream dan MaxText GitHub

       git clone -b jetstream-v0.2.0 https://github.com/google/maxtext.git
       git clone -b v0.2.0 https://github.com/google/JetStream.git
    
  2. Menyiapkan MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Mengonversi checkpoint model

Anda dapat menjalankan JetStream MaxText Server dengan model Gemma atau Llama2. Bagian ini menjelaskan cara menjalankan server JetStream MaxText dengan berbagai ukuran model ini.

Menggunakan checkpoint model Gemma

  1. Download checkpoint Gemma dari Kaggle.
  2. Menyalin checkpoint ke bucket Cloud Storage Anda

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    Untuk contoh yang menyertakan nilai untuk ${YOUR_CKPT_PATH} dan ${CHKPT_BUCKET}, lihat skrip konversi.

  3. Konversi checkpoint Gemma menjadi checkpoint yang tidak dipindai dan kompatibel dengan MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Menggunakan checkpoint model Llama2

  1. Download checkpoint Llama2 dari komunitas open source, atau gunakan checkpoint yang telah Anda buat.

  2. Salin checkpoint ke bucket Cloud Storage Anda.

       gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    Untuk contoh yang menyertakan nilai untuk ${YOUR_CKPT_PATH} dan ${CHKPT_BUCKET}, lihat skrip konversi.

  3. Konversi checkpoint Llama2 menjadi checkpoint yang tidak dipindai dan kompatibel dengan MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Menjalankan server JetStream MaxText

Bagian ini menjelaskan cara menjalankan server MaxText menggunakan checkpoint yang kompatibel dengan MaxText.

Mengonfigurasi variabel lingkungan untuk server MaxText

Ekspor variabel lingkungan berikut berdasarkan model yang Anda gunakan. Gunakan nilai untuk UNSCANNED_CKPT_PATH dari output model_ckpt_conversion.sh.

Membuat variabel lingkungan Gemma-7b untuk flag server

Konfigurasikan tanda server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Membuat variabel lingkungan Llama2-7b untuk flag server

Konfigurasikan tanda server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=6

Membuat variabel lingkungan Llama2-13b untuk flag server

Konfigurasikan tanda server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=2

Memulai server JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Deskripsi tanda JetStream MaxText Server

tokenizer_path
Jalur ke tokenizer (harus cocok dengan model Anda).
load_parameters_path
Memuat parameter (tidak ada status pengoptimal) dari direktori tertentu
per_device_batch_size
ukuran batch decoding per perangkat (1 chip TPU = 1 perangkat)
max_prefill_predict_length
Panjang maksimum untuk pengisian otomatis saat melakukan autoregresi
max_target_length
Panjang urutan maksimum
model_name
Nama model
ici_fsdp_parallelism
Jumlah shard untuk paralelisme FSDP
ici_autoregressive_parallelism
Jumlah shard untuk paralelisme autoregresif
ici_tensor_parallelism
Jumlah shard untuk paralelisme tensor
weight_dtype
Jenis data berat (misalnya bfloat16)
scan_layers
Flag boolean pemindaian lapisan

Mengirim permintaan pengujian ke server JetStream MaxText

cd ~
python JetStream/jetstream/tools/requester.py

Outputnya akan mirip dengan berikut ini:

Sending request to: dns:///[::1]:9000
Prompt: Today is a good day
Response:  to be a fan

Menjalankan benchmark dengan server JetStream MaxText

Untuk mendapatkan hasil tolok ukur terbaik, aktifkan kuantisasi (gunakan checkpoint yang dilatih atau disesuaikan oleh AQT untuk memastikan akurasi) untuk bobot dan cache KV. Untuk mengaktifkan kuantisasi, tetapkan flag kuantisasi:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Membuat tolok ukur Gemma-7b

Untuk menjalankan benchmark Gemma-7b, lakukan hal berikut:

  1. Unduh {i>dataset <i}ShareGPT.
  2. Pastikan untuk menggunakan tokenizer Gemma (tokenizer.gemma) saat menjalankan Gemma 7b.
  3. Tambahkan flag --warmup-first untuk proses pertama Anda guna menyiapkan server.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer /home/$USER/maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true

Benchmark dengan Llama2 yang lebih besar

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true

Pembersihan

Agar akun Google Cloud Anda tidak ditagih atas resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus masing-masing resource.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env