Inferensi PyTorch JetStream pada VM Cloud TPU v5e


JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA (TPU).

Sebelum memulai

Ikuti langkah-langkah di Menyiapkan lingkungan Cloud TPU untuk membuat project Google Cloud , mengaktifkan TPU API, menginstal TPU CLI, dan meminta kuota TPU.

Ikuti langkah-langkah di Membuat Cloud TPU menggunakan CreateNode API untuk membuat setelan VM TPU --accelerator-type ke v5litepod-8.

Meng-clone repositori JetStream dan menginstal dependensi

  1. Menghubungkan ke VM TPU menggunakan SSH

    • Tetapkan ${TPU_NAME} ke nama TPU Anda.
    • Tetapkan ${PROJECT} ke project Google Cloud Anda
    • Tetapkan ${ZONE} ke zona Google Cloud tempat Anda membuat TPU
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Mendapatkan kode jetstream-pytorch bash git clone https://github.com/google/jetstream-pytorch.git git checkout jetstream-v0.2.4

(opsional) Buat virtual env menggunakan venv atau conda, lalu aktifkan.

sudo apt install python3.10-venv
python -m venv venv
source venv/bin/activate
  1. Jalankan skrip penginstalan:
cd jetstream-pytorch
source install_everything.sh

Menjalankan jetstream pytorch

Mencantumkan model yang didukung

jpt list

Tindakan ini akan mencetak daftar model dan varian dukungan:

meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1

Untuk menjalankan server jetstream-pytorch dengan satu model: bash jpt serve --model_id meta-llama/Llama-2-7b-chat-hf

Saat pertama kali Anda menjalankan model ini, perintah jpt serve akan mencoba mendownload bobot dari HuggingFace yang mengharuskan Anda melakukan autentikasi dengan HuggingFace.

Untuk mengautentikasi, jalankan huggingface-cli login untuk menetapkan token akses, atau teruskan token akses HuggingFace ke perintah jpt serve menggunakan flag --hf_token:

jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --hf_token=...

Untuk mengetahui informasi selengkapnya tentang token akses HuggingFace, lihat Token Akses.

Untuk login menggunakan HuggingFace Hub, jalankan perintah berikut dan ikuti petunjuknya:

pip install -U "huggingface_hub[cli]"
huggingface-cli login

Setelah bobot didownload, Anda tidak perlu lagi menentukan flag --hf_token.

Untuk menjalankan model ini dengan kuantisasi int8, tambahkan --quantize_weights=1. Kuantifikasi akan dilakukan saat penerbangan saat beban berat dimuat.

Bobot yang didownload dari HuggingFace disimpan secara default di direktori yang disebut folder checkpoints di direktori tempat Anda menjalankan jpt. Anda juga dapat mengubah menentukan direktori menggunakan flag --working_dir.

Jika Anda ingin menggunakan checkpoint Anda sendiri, tempatkan checkpoint tersebut di dalam direktori checkpoints/<org>/<model>/hf_original (atau subdir yang sesuai di --working_dir). Misalnya, checkpoint Llama2-7b akan berada di checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors. Anda dapat mengganti file ini dengan bobot yang dimodifikasi dalam format HuggingFace.

Jalankan tolok ukur

Ubah ke folder deps/JetStream yang didownload saat Anda menjalankan install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Untuk informasi selengkapnya, lihat deps/JetStream/benchmarks/README.md.

Error umum

Jika Anda mendapatkan error Unexpected keyword argument 'device', coba langkah berikut:

  • Meng-uninstal dependensi jax dan jaxlib
  • Menginstal ulang menggunakan source install_everything.sh

Jika Anda mendapatkan error Out of memory, coba langkah berikut:

  • Menggunakan ukuran batch yang lebih kecil
  • Menggunakan kuantisasi

Pembersihan

Agar tidak perlu membayar biaya pada akun Google Cloud Anda untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource.

  1. Membersihkan repositori GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Membersihkan lingkungan virtual python

    rm -rf .env
    
  3. Menghapus resource TPU

    Untuk mengetahui informasi selengkapnya, lihat Menghapus resource TPU.