Inferensi PyTorch JetStream pada VM Cloud TPU v5e


JetStream adalah mesin yang dioptimalkan untuk throughput dan memori untuk inferensi model bahasa besar (LLM) di perangkat XLA (TPU).

Sebelum memulai

Ikuti langkah-langkah di Menyiapkan lingkungan Cloud TPU untuk membuat project Google Cloud, mengaktifkan TPU API, menginstal TPU CLI, dan meminta kuota TPU.

Ikuti langkah-langkah di Membuat Cloud TPU menggunakan CreateNode API untuk membuat setelan VM TPU --accelerator-type ke v5litepod-8.

Meng-clone repositori JetStream dan menginstal dependensi

  1. Menghubungkan ke VM TPU menggunakan SSH

    • Tetapkan ${TPU_NAME} ke nama TPU Anda.
    • Menetapkan ${PROJECT} ke project Google Cloud Anda
    • Tetapkan ${ZONE} ke zona Google Cloud tempat Anda akan membuat TPU
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Meng-clone repositori JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Opsional) Buat lingkungan virtual Python menggunakan venv atau conda dan aktifkan.

  3. Menjalankan skrip penginstalan

       cd jetstream-pytorch
       source install_everything.sh
    

Mendownload dan mengonversi bobot

  1. Download bobot Llama resmi dari GitHub.

  2. Konversikan bobot.

    • Tetapkan ${IN_CKPOINT} ke lokasi yang berisi bobot Llama
    • Menetapkan ${OUT_CKPOINT} ke checkpoint tulis lokasi
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Menjalankan mesin PyTorch JetStream secara lokal

Untuk menjalankan mesin JetStream PyTorch secara lokal, tetapkan jalur tokenizer:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Menjalankan mesin PyTorch JetStream dengan Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Menjalankan mesin PyTorch JetStream dengan Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Menjalankan server JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

CATATAN: parameter --platform=tpu= harus menentukan jumlah perangkat TPU (yaitu 4 untuk v4-8 dan 8 untuk v5lite-8). Misalnya, --platform=tpu=8.

Setelah menjalankan run_server.py, mesin JetStream PyTorch siap menerima panggilan gRPC.

Jalankan tolok ukur

Ubah ke folder deps/JetStream yang didownload saat Anda menjalankan install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Untuk informasi selengkapnya, lihat deps/JetStream/benchmarks/README.md.

Error umum

Jika Anda mendapatkan error Unexpected keyword argument 'device', coba langkah berikut:

  • Meng-uninstal dependensi jax dan jaxlib
  • Menginstal ulang menggunakan source install_everything.sh

Jika Anda mendapatkan error Out of memory, coba langkah berikut:

  • Menggunakan ukuran batch yang lebih kecil
  • Menggunakan kuantisasi

Pembersihan

Agar tidak perlu membayar biaya pada akun Google Cloud Anda untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource.

  1. Membersihkan repositori GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Membersihkan lingkungan virtual python

    rm -rf .env
    
  3. Menghapus resource TPU

    Untuk informasi selengkapnya, lihat Menghapus resource TPU.