Inferensi PyTorch JetStream pada VM Cloud TPU v5e


JetStream adalah mesin throughput dan memori yang dioptimalkan untuk inferensi model bahasa besar (LLM) pada perangkat XLA (TPU).

Sebelum memulai

Ikuti langkah-langkah di bagian Menyiapkan lingkungan Cloud TPU untuk membuat project Google Cloud, mengaktifkan TPU API, menginstal TPU CLI, dan meminta kuota TPU.

Ikuti langkah-langkah di bagian Membuat Cloud TPU menggunakan CreateNode API untuk membuat setelan VM TPU --accelerator-type ke v5litepod-8.

Meng-clone repositori JetStream dan menginstal dependensi

  1. Menghubungkan ke VM TPU menggunakan SSH

    • Tetapkan ${TPU_NAME} ke nama TPU Anda.
    • Tetapkan ${PROJECT} ke project Google Cloud Anda
    • Tetapkan ${ZONE} ke zona Google Cloud untuk membuat TPU Anda
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Meng-clone repositori JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Opsional) Buat lingkungan Python virtual menggunakan venv atau conda, lalu aktifkan.

  3. Menjalankan skrip penginstalan

       cd jetstream-pytorch
       source install_everything.sh
    

Download dan konversi berat

  1. Download bobot Llama resmi dari GitHub.

  2. Mengonversi bobot.

    • Tetapkan ${IN_CKPOINT} ke lokasi yang berisi bobot Llama
    • Tetapkan ${OUT_CKPOINT} ke checkpoint tulis lokasi
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Menjalankan mesin JetStream PyTorch secara lokal

Untuk menjalankan mesin JetStream PyTorch secara lokal, setel jalur tokenizer:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Jalankan mesin JetStream PyTorch dengan Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Jalankan mesin JetStream PyTorch dengan Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Menjalankan server JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

CATATAN: parameter --platform=tpu= harus menentukan jumlah perangkat TPU (yaitu 4 untuk v4-8 dan 8 untuk v5lite-8). Misalnya, --platform=tpu=8.

Setelah menjalankan run_server.py, mesin JetStream PyTorch siap menerima panggilan gRPC.

Menjalankan benchmark

Ubah ke folder deps/JetStream yang didownload saat Anda menjalankan install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Untuk mengetahui informasi selengkapnya, lihat deps/JetStream/benchmarks/README.md.

Error umum

Jika Anda mendapatkan error Unexpected keyword argument 'device', coba langkah berikut:

  • Meng-uninstal dependensi jax dan jaxlib
  • Instal ulang menggunakan source install_everything.sh

Jika Anda mendapatkan error Out of memory, coba langkah berikut:

  • Gunakan ukuran tumpukan yang lebih kecil
  • Menggunakan kuantisasi

Pembersihan

Agar tidak dikenakan biaya pada akun Google Cloud Anda untuk resource yang digunakan dalam tutorial ini, hapus project yang berisi resource tersebut, atau simpan project dan hapus setiap resource-nya.

  1. Membersihkan repositori GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Membersihkan lingkungan virtual Python

    rm -rf .env
    
  3. Menghapus resource TPU

    Untuk mengetahui informasi selengkapnya, lihat Menghapus resource TPU Anda.