v5e Cloud TPU VM에서 JetStream MaxText 추론


JetStream은 XLA 기기(TPU)에서 대규모 언어 모델(LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다.

시작하기 전에

TPU 리소스 관리의 단계에 따라 --accelerator-type에서 v5litepod-8로 설정하는 TPU VM을 만들고 TPU VM에 연결합니다.

JetStream 및 MaxText 설정

  1. JetStream 및 MaxText GitHub 저장소 다운로드

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. MaxText 설정

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

모델 체크포인트 변환

JetStream MaxText 서버를 Gemma 또는 Llama2 모델과 함께 실행할 수 있습니다. 이 섹션에서는 이러한 모델의 다양한 크기로 JetStream MaxText 서버를 실행하는 방법을 설명합니다.

Gemma 모델 체크포인트 사용

  1. Kaggle에서 Gemma 체크포인트를 다운로드합니다.
  2. 체크포인트를 Cloud Storage 버킷에 복사합니다.

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    ${YOUR_CKPT_PATH}${CHKPT_BUCKET}의 값이 포함된 예시는 변환 스크립트를 참조하세요.

  3. Gemma 체크포인트를 MaxText와 호환되는 스캔되지 않은 체크포인트로 변환합니다.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Llama2 모델 체크포인트 사용

  1. 오픈소스 커뮤니티에서 Llama2 체크포인트를 다운로드하거나 사용자가 만든 체크포인트를 사용하세요.

  2. 체크포인트를 Cloud Storage 버킷에 복사합니다.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    ${YOUR_CKPT_PATH}${CHKPT_BUCKET}의 값이 포함된 예시는 변환 스크립트를 참조하세요.

  3. Llama2 체크포인트를 MaxText와 호환되는 스캔되지 않은 체크포인트로 변환합니다.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

JetStream MaxText 서버 실행

이 섹션에서는 MaxText 호환 체크포인트를 사용하여 MaxText 서버를 실행하는 방법을 설명합니다.

MaxText 서버용 환경 변수 구성

사용 중인 모델을 기준으로 다음 환경 변수를 내보냅니다. model_ckpt_conversion.sh의 출력에서 UNSCANNED_CKPT_PATH에 이 값을 사용합니다.

서버 플래그용 Gemma-7b 환경 변수 만들기

JetStream MaxText 서버 플래그를 구성합니다.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

서버 플래그용 Llama2-7b 환경 변수 만들기

JetStream MaxText 서버 플래그를 구성합니다.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

서버 플래그용 Llama2-13b 환경 변수 만들기

JetStream MaxText 서버 플래그를 구성합니다.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

JetStream MaxText 서버 시작

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

JetStream MaxText Server 플래그 설명

tokenizer_path
tokenizer의 경로입니다(모델과 일치해야 함).
load_parameters_path
특정 디렉터리에서 매개변수(최적화 도구 상태 없음)를 로드합니다.
per_device_batch_size
기기당 디코딩 배치 크기(TPU 칩 1개 = 기기 1개)
max_prefill_predict_length
자동 회귀 수행 시 미리 입력의 최대 길이
max_target_length
최대 시퀀스 길이
model_name
모델 이름
ici_fsdp_parallelism
FSDP 동시 로드의 샤드 수
ici_autoregressive_parallelism
: 자동 회귀 동시 로드의 샤드 수
ici_tensor_parallelism
텐서 동시 로드의 샤드 수
weight_dtype
가중치 데이터 유형(예: bfloat16)
scan_layers
스캔 레이어 불리언 플래그(추론을 위해 `false` 로 설정)

JetStream MaxText 서버에 테스트 요청 전송

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

출력은 다음과 비슷하게 표시됩니다.

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

JetStream MaxText 서버로 벤치마크 실행

최상의 벤치마크 결과를 얻으려면 가중치와 KV 캐시 모두에 대해 양자화(정확도 보장을 위해 AQT 훈련 또는 미세 조정된 체크포인트 사용)를 사용 설정합니다. 양자화를 사용 설정하려면 양자화 플래그를 설정합니다.

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Gemma-7b 벤치마킹

Gemma-7b를 벤치마킹하려면 다음 단계를 따르세요.

  1. ShareGPT 데이터 세트를 다운로드합니다.
  2. Gemma 7b를 실행할 때는 Gemma tokenizer(tokenizer.gemma)를 사용해야 합니다.
  3. 첫 번째 실행에서 --warmup-first 플래그를 추가하여 서버를 워밍업합니다.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

더 큰 Llama2 벤치마킹

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

삭제

이 가이드에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트를 유지하고 개별 리소스를 삭제하세요.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env