v5e Cloud TPU VM의 JetStream PyTorch 추론


JetStream은 XLA 기기(TPU)에서 대규모 언어 모델(LLM) 추론을 위한 처리량 및 메모리 최적화 엔진입니다.

시작하기 전에

Cloud TPU 환경 설정의 단계에 따라 Google Cloud 프로젝트를 만들고 TPU API를 활성화하고 TPU CLI를 설치하고 TPU 할당량을 요청합니다.

CreateNode API를 사용하여 Cloud TPU 만들기의 단계에 따라 --accelerator-typev5litepod-8로 설정하는 TPU VM을 만듭니다.

JetStream 저장소 클론 및 종속 항목 설치

  1. SSH를 사용하여 TPU VM에 연결하기

    • ${TPU_NAME}을 TPU 이름으로 설정합니다.
    • ${PROJECT}을 Google Cloud 프로젝트로 설정합니다.
    • TPU를 만들 Google Cloud 영역에 ${ZONE}을 설정합니다.
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. JetStream 저장소 클론

       git clone https://github.com/google/jetstream-pytorch.git
    

    (선택사항) venv 또는 conda를 사용하여 가상 Python 환경을 만들고 활성화합니다.

  3. 설치 스크립트를 실행합니다.

       cd jetstream-pytorch
       source install_everything.sh
    

가중치 다운로드 및 변환

  1. GitHub에서 공식 Llama 가중치를 다운로드하세요.

  2. 가중치를 변환합니다.

    • ${IN_CKPOINT}을 Llama 가중치가 포함된 위치로 설정합니다.
    • ${OUT_CKPOINT}을 위치 쓰기 체크포인트로 설정합니다.
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

JetStream PyTorch 엔진을 로컬에서 실행

JetStream PyTorch 엔진을 로컬에서 실행하려면 tokenizer 경로를 설정합니다.

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Llama 7B로 JetStream PyTorch 엔진 실행

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Llama 13b로 JetStream PyTorch 엔진 실행

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

JetStream 서버 실행

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

참고: --platform=tpu= 매개변수는 TPU 기기(v4-8의 경우 4, v5lite-8의 경우 8)의 수를 지정해야 합니다. 예를 들면 --platform=tpu=8입니다.

run_server.py를 실행하면 JetStream PyTorch 엔진은 gRPC 호출을 수신할 준비가 됩니다.

벤치마크 실행

install_everything.sh를 실행할 때 다운로드한 deps/JetStream 폴더로 변경합니다.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

자세한 내용은 deps/JetStream/benchmarks/README.md를 참조하세요.

일반적인 오류

Unexpected keyword argument 'device' 오류가 발생하면 다음을 시도합니다.

  • jaxjaxlib 종속 항목 제거
  • source install_everything.sh를 사용하여 재설치

Out of memory 오류가 발생하면 다음을 시도합니다.

  • 더 작은 배치 크기 사용
  • 양자화 사용

삭제

이 튜토리얼에서 사용된 리소스 비용이 Google Cloud 계정에 청구되지 않도록 하려면 리소스가 포함된 프로젝트를 삭제하거나 프로젝트를 유지하고 개별 리소스를 삭제하세요.

  1. GitHub 저장소 삭제

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Python 가상 환경 정리

    rm -rf .env
    
  3. TPU 리소스 삭제

    자세한 내용은 TPU 리소스 삭제를 참조하세요.