v5e Cloud TPU VM에서 JetStream PyTorch 추론


JetStream은 XLA 기기(TPU)에서 대규모 언어 모델(LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다.

시작하기 전에

Cloud TPU 환경 설정의 단계에 따라 Google Cloud 프로젝트를 만들고, TPU API를 활성화하고, TPU CLI를 설치한 후 TPU 할당량을 요청합니다.

CreateNode API를 사용하여 Cloud TPU 만들기의 단계에 따라 --accelerator-typev5litepod-8로 설정하는 TPU VM을 만듭니다.

JetStream 저장소 클론 및 종속 항목 설치

  1. SSH를 사용하여 TPU VM에 연결하기

    • ${TPU_NAME}을 TPU 이름으로 설정합니다.
    • ${PROJECT}을 Google Cloud 프로젝트로 설정합니다.
    • TPU를 만들 Google Cloud 영역에 ${ZONE}을 설정합니다.
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. JetStream 저장소 클론

       git clone https://github.com/google/jetstream-pytorch.git
    

    (선택사항) venv 또는 conda를 사용하여 가상 Python 환경을 만들고 활성화합니다.

  3. 설치 스크립트를 실행합니다.

       cd jetstream-pytorch
       source install_everything.sh
    

가중치 다운로드 및 변환

  1. GitHub에서 공식 Llama 가중치를 다운로드합니다.

  2. 가중치를 변환합니다.

    • ${IN_CKPOINT}을 Llama 가중치가 포함된 위치로 설정합니다.
    • ${OUT_CKPOINT}을 위치 쓰기 체크포인트로 설정합니다.
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

로컬에서 JetStream PyTorch 엔진 실행

JetStream PyTorch 엔진을 로컬에서 실행하려면 토크나이저 경로를 설정합니다.

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Llama 7B로 JetStream PyTorch 엔진 실행

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Llama 13b로 JetStream PyTorch 엔진 실행

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

JetStream 서버 실행

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

참고: --platform=tpu= 파라미터는 TPU 기기(v4-8의 경우 4, v5lite-8의 경우 8)의 수를 지정해야 합니다. 예를 들면 --platform=tpu=8입니다.

run_server.py를 실행한 후에는 JetStream PyTorch 엔진이 gRPC 호출을 수신할 준비가 됩니다.

벤치마크 실행

install_everything.sh를 실행할 때 다운로드된 deps/JetStream 폴더로 변경합니다.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

자세한 내용은 deps/JetStream/benchmarks/README.md를 참조하세요.

일반적인 오류

Unexpected keyword argument 'device' 오류가 발생하면 다음을 시도해 보세요.

  • jaxjaxlib 종속 항목 제거
  • source install_everything.sh를 사용하여 재설치

Out of memory 오류가 발생하면 다음을 시도해 보세요.

  • 더 작은 배치 크기 사용
  • 양자화 사용

삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트를 유지하고 개별 리소스를 삭제하세요.

  1. GitHub 저장소 삭제

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Python 가상 환경 삭제

    rm -rf .env
    
  3. TPU 리소스 삭제

    자세한 내용은 TPU 리소스 삭제를 참고하세요.