Inferenza PyTorch di JetStream sulla VM Cloud TPU v5e


JetStream è un motore ottimizzato per velocità effettiva e memoria per l'inferenza dei modelli linguistici di grandi dimensioni (LLM) su dispositivi XLA (TPU).

Prima di iniziare

Segui i passaggi descritti in Configurare l'ambiente Cloud TPU per creare un progetto Google Cloud, attivare l'API TPU, installare l'interfaccia a riga di comando TPU e richiedere la quota TPU.

Segui i passaggi descritti in Creare una Cloud TPU utilizzando l'API CreateNode per creare un'impostazione VM TPU da --accelerator-type a v5litepod-8.

clona il repository JetStream e installa le dipendenze

  1. Connettiti alla VM TPU tramite SSH

    • Imposta ${TPU_NAME} con il nome della TPU.
    • Imposta ${PROJECT} sul progetto Google Cloud
    • Imposta ${ZONE} sulla zona Google Cloud in cui creare le TPU
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. clona il repository JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Facoltativo) Crea un ambiente Python virtuale utilizzando venv o conda e attivalo.

  3. Esegui lo script di installazione

       cd jetstream-pytorch
       source install_everything.sh
    

Scarica e converti le ponderazioni

  1. Scarica i pesi Lama ufficiali da GitHub.

  2. Converti le ponderazioni.

    • Imposta ${IN_CKPOINT} sulla località che contiene i pesi lama
    • Imposta ${OUT_CKPOINT} su un checkpoint per la scrittura della località
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Eseguire il motore JetStream PyTorch in locale

Per eseguire il motore JetStream PyTorch localmente, imposta il percorso del tokenizzatore:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Esegui il motore JetStream PyTorch con Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Esegui il motore JetStream PyTorch con Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Esegui il server JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

NOTA: il parametro --platform=tpu= deve specificare il numero di dispositivi TPU (4 per v4-8 e 8 per v5lite-8). Ad esempio, --platform=tpu=8.

Dopo aver eseguito run_server.py, il motore JetStream PyTorch è pronto a ricevere chiamate gRPC.

Eseguire benchmark

Passa alla cartella deps/JetStream che è stata scaricata quando hai eseguito install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Per ulteriori informazioni, vedi deps/JetStream/benchmarks/README.md.

Errori tipici

Se viene visualizzato un errore Unexpected keyword argument 'device', prova a procedere nel seguente modo:

  • Disinstalla le dipendenze jax e jaxlib
  • Reinstalla utilizzando source install_everything.sh

Se viene visualizzato un errore Out of memory, prova a procedere nel seguente modo:

  • Usa batch di dimensioni minori
  • Utilizza la quantizzazione

Esegui la pulizia

Per evitare che al tuo Account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.

  1. Esegui la pulizia dei repository GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. esegui la pulizia dell'ambiente virtuale Python

    rm -rf .env
    
  3. Elimina le risorse TPU

    Per maggiori informazioni, vedi Eliminare le risorse TPU.