Inferenza JetStream PyTorch su VM Cloud TPU v5e


JetStream è un motore ottimizzato per la velocità effettiva e la memoria per l'inferenza dei modelli linguistici di grandi dimensioni (LLM) sui dispositivi XLA (TPU).

Prima di iniziare

Segui i passaggi descritti in Configurare l'ambiente Cloud TPU per creare un progetto Google Cloud, attivare l'API TPU, installare TPU CLI e richiedere la quota TPU.

Segui i passaggi descritti in Creare una Cloud TPU utilizzando l'API CreateNode per creare una VM TPU impostando --accelerator-type su v5litepod-8.

Clona il repository JetStream e installa le dipendenze

  1. Connettiti alla VM TPU tramite SSH

    • Imposta ${TPU_NAME} sul nome della TPU.
    • Imposta ${PROJECT} sul tuo progetto Google Cloud
    • Imposta ${ZONE} sulla zona Google Cloud in cui creare le TPU
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Clona il repository JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Facoltativo) Crea un ambiente Python virtuale utilizzando venv o conda e attivalo.

  3. Esegui lo script di installazione

       cd jetstream-pytorch
       source install_everything.sh
    

Scaricare e convertire i pesi

  1. Scarica i pesi ufficiali di Llama da GitHub.

  2. Converti i pesi.

    • Imposta ${IN_CKPOINT} sulla posizione che contiene i pesi di Llama
    • Imposta ${OUT_CKPOINT} su un punto di controllo di scrittura della posizione
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Esegui il motore PyTorch di JetStream localmente

Per eseguire il motore PyTorch di JetStream in locale, imposta il percorso del tokenizzatore:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Esegui il motore PyTorch JetStream con Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Esegui il motore PyTorch JetStream con Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Esegui il server JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

NOTA: il parametro --platform=tpu= deve specificare il numero di dispositivi TPU (4 per v4-8 e 8 per v5lite-8). Ad esempio, --platform=tpu=8.

Dopo aver eseguito run_server.py, il motore PyTorch di JetStream è pronto per ricevere chiamate gRPC.

Eseguire benchmark

Vai alla cartella deps/JetStream scaricata quando hai eseguito install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Per ulteriori informazioni, consulta deps/JetStream/benchmarks/README.md.

Errori tipici

Se ricevi un errore Unexpected keyword argument 'device', prova a procedere nel seguente modo:

  • Disinstalla le dipendenze di jax e jaxlib
  • Reinstalla utilizzando source install_everything.sh

Se ricevi un errore Out of memory, prova a procedere nel seguente modo:

  • Utilizza dimensioni dei batch più piccole
  • Utilizza la quantizzazione

Esegui la pulizia

Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

  1. Ripulire i repository GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Ripulisci l'ambiente virtuale Python

    rm -rf .env
    
  3. Elimina le risorse TPU

    Per ulteriori informazioni, consulta la sezione Eliminare le risorse TPU.