Inferensi MaxText JetStream pada VM Cloud TPU v5e


JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA (TPU).

Sebelum memulai

Ikuti langkah-langkah di Mengelola resource TPU untuk membuat setelan VM TPU --accelerator-type ke v5litepod-8, dan terhubung ke VM TPU.

Menyiapkan JetStream dan MaxText

  1. Mendownload JetStream dan repositori GitHub MaxText

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Menyiapkan MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Mengonversi checkpoint model

Anda dapat menjalankan Server JetStream MaxText dengan model Gemma atau Llama2. Bagian ini menjelaskan cara menjalankan server JetStream MaxText dengan berbagai ukuran model ini.

Menggunakan checkpoint model Gemma

  1. Download titik pemeriksaan Gemma dari Kaggle.
  2. Menyalin titik pemeriksaan ke bucket Cloud Storage Anda

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Untuk contoh yang menyertakan nilai untuk ${YOUR_CKPT_PATH} dan ${CHKPT_BUCKET}, lihat skrip konversi.

  3. Mengonversi checkpoint Gemma menjadi checkpoint yang tidak dipindai dan kompatibel dengan MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Menggunakan checkpoint model Llama2

  1. Download checkpoint Llama2 dari komunitas open source, atau gunakan checkpoint yang telah Anda buat.

  2. Salin titik pemeriksaan ke bucket Cloud Storage Anda.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Untuk contoh yang menyertakan nilai untuk ${YOUR_CKPT_PATH} dan ${CHKPT_BUCKET}, lihat skrip konversi.

  3. Konversikan checkpoint Llama2 menjadi checkpoint yang tidak dipindai dan kompatibel dengan MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Menjalankan server JetStream MaxText

Bagian ini menjelaskan cara menjalankan server MaxText menggunakan checkpoint yang kompatibel dengan MaxText.

Mengonfigurasi variabel lingkungan untuk server MaxText

Ekspor variabel lingkungan berikut berdasarkan model yang Anda gunakan. Gunakan nilai untuk UNSCANNED_CKPT_PATH dari output model_ckpt_conversion.sh.

Membuat variabel lingkungan Gemma-7b untuk flag server

Konfigurasikan flag server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Membuat variabel lingkungan Llama2-7b untuk flag server

Konfigurasikan flag server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Membuat variabel lingkungan Llama2-13b untuk flag server

Konfigurasikan flag server JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Memulai server JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Deskripsi flag Server MaxText JetStream

tokenizer_path
Jalur ke tokenizer (harus cocok dengan model Anda).
load_parameters_path
Memuat parameter (tanpa status pengoptimal) dari direktori tertentu
per_device_batch_size
ukuran batch decoding per perangkat (1 chip TPU = 1 perangkat)
max_prefill_predict_length
Panjang maksimum untuk praisi saat melakukan autoregresi
max_target_length
Panjang urutan maksimum
model_name
Nama model
ici_fsdp_parallelism
Jumlah shard untuk paralelisme FSDP
ici_autoregressive_parallelism
Jumlah shard untuk paralelisme autoregresif
ici_tensor_parallelism
Jumlah shard untuk paralelisme tensor
weight_dtype
Jenis data bobot (misalnya bfloat16)
scan_layers
Flag boolean lapisan pemindaian (ditetapkan ke `false` untuk inferensi)

Mengirim permintaan pengujian ke server JetStream MaxText

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

Outputnya akan mirip dengan berikut ini:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Menjalankan benchmark dengan server JetStream MaxText

Untuk mendapatkan hasil benchmark terbaik, aktifkan kuantisasi (gunakan checkpoint yang dilatih AQT atau yang telah disesuaikan dengan baik untuk memastikan akurasi) untuk bobot dan cache KV. Untuk mengaktifkan kuantisasi, tetapkan flag kuantisasi:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Melakukan benchmark Gemma-7b

Untuk menjalankan benchmark Gemma-7b, lakukan hal berikut:

  1. Download set data ShareGPT.
  2. Pastikan untuk menggunakan tokenizer Gemma (tokenizer.gemma) saat menjalankan Gemma 7b.
  3. Tambahkan flag --warmup-first untuk menjalankan pertama kalinya guna memanaskan server.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Melakukan benchmark Llama2 yang lebih besar

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Pembersihan

Agar akun Google Cloud Anda tidak ditagih atas resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus masing-masing resource.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env