v5e Cloud TPU VM での JetStream MaxText 推論


JetStream は、XLA デバイス(TPU)での大規模言語モデル(LLM)推論向けのスループットとメモリ最適化エンジンです。

始める前に

TPU リソースを管理するの手順に沿って、--accelerator-typev5litepod-8 に設定する TPU VM を作成し、TPU VM に接続します。

JetStream と MaxText を設定する

  1. JetStream と MaxText GitHub リポジトリをダウンロードする

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. MaxText を設定する

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

モデルのチェックポイントを変換する

JetStream MaxText サーバーは、Gemma モデルまたは Llama2 モデルで実行できます。このセクションでは、さまざまなサイズのモデルで JetStream MaxText サーバーを実行する方法について説明します。

Gemma モデルのチェックポイントを使用する

  1. Kaggle から Gemma チェックポイントをダウンロードします。
  2. チェックポイントを Cloud Storage バケットにコピーします

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    ${YOUR_CKPT_PATH}${CHKPT_BUCKET} の値を含む例については、コンバージョン スクリプトをご覧ください。

  3. Gemma チェックポイントを MaxText 対応のスキャン済みチェックポイントに変換します。

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Llama2 モデル チェックポイントを使用する

  1. オープンソース コミュニティから Llama2 チェックポイントをダウンロードするか、生成したチェックポイントを使用します。

  2. チェックポイントを Cloud Storage バケットにコピーします。

       gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    ${YOUR_CKPT_PATH}${CHKPT_BUCKET} の値を含む例については、コンバージョン スクリプトをご覧ください。

  3. Llama2 チェックポイントを MaxText 対応のスキャン済みチェックポイントに変換します。

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

JetStream MaxText サーバーを実行する

このセクションでは、MaxText 対応のチェックポイントを使用して MaxText サーバーを実行する方法について説明します。

MaxText サーバー用に環境変数を構成する

使用しているモデルに基づいて、次の環境変数をエクスポートします。model_ckpt_conversion.sh 出力には、UNSCANNED_CKPT_PATH の値を使用します。

サーバーフラグの Gemma-7b 環境変数を作成する

JetStream MaxText サーバーのフラグを構成します。

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

サーバーフラグ用の Llama2-7b 環境変数を作成する

JetStream MaxText サーバーのフラグを構成します。

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

サーバーフラグ用の Llama2-13b 環境変数を作成する

JetStream MaxText サーバー フラグを構成します。

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

JetStream MaxText サーバーを起動する

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

JetStream MaxText Server フラグの説明

tokenizer_path
トークナイザのパス(モデルと一致している必要があります)。
load_parameters_path
特定のディレクトリからパラメータを読み込みます(オプティマイザーの状態なし)。
per_device_batch_size
デバイスごとのデコード バッチサイズ(1 つの TPU チップ = 1 台のデバイス)
max_prefill_predict_length
自動回帰を行う場合のプリフィルの最大長
max_target_length
シーケンスの最大長
model_name
モデル名
ici_fsdp_parallelism
FSDP 並列処理のシャードの数
ici_autoregressive_parallelism
自己回帰並列処理のシャードの数
ici_tensor_parallelism
テンソル並列処理のシャードの数
weight_dtype
重みデータ型(bfloat16 など)
scan_layers
スキャンレイヤのブール値フラグ(推論用に「false」に設定)

JetStream MaxText サーバーにテスト リクエストを送信する

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

出力は次のようになります。

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

JetStream MaxText サーバーでベンチマークを実行する

最適なベンチマーク結果を得るには、重みと KV キャッシュの両方で量子化を有効にします(AQT でトレーニングされたチェックポイントまたは微調整されたチェックポイントを使用して、精度を保証します)。量子化を有効にするには、量子化フラグを設定します。

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Gemma-7b のベンチマーク

Gemma-7b のベンチマークを行うには、次の手順を行います。

  1. ShareGPT データセットをダウンロードします。
  2. Gemma 7b を実行するときは、必ず Gemma トークナイザ(tokenizer.gemma)を使用してください。
  3. サーバーをウォームアップするために、最初の実行に --warmup-first フラグを追加します。
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

より大きな Llama2 のベンチマーク

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env