Prima di iniziare
Segui i passaggi descritti in Configurare l'ambiente Cloud TPU per creare un progetto Google Cloud, attivare l'API TPU, installare TPU CLI e richiedere la quota TPU.
Segui i passaggi descritti in Creare una Cloud TPU utilizzando l'API CreateNode per
creare una VM TPU impostando --accelerator-type
su v5litepod-8
.
Clona il repository JetStream e installa le dipendenze
Connettiti alla VM TPU tramite SSH
- Imposta ${TPU_NAME} sul nome della TPU.
- Imposta ${PROJECT} sul tuo progetto Google Cloud
- Imposta ${ZONE} sulla zona Google Cloud in cui creare le TPU
gcloud compute config-ssh gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
Clona il repository JetStream
git clone https://github.com/google/jetstream-pytorch.git
(Facoltativo) Crea un ambiente Python virtuale utilizzando
venv
oconda
e attivalo.Esegui lo script di installazione
cd jetstream-pytorch source install_everything.sh
Scaricare e convertire i pesi
Scarica i pesi ufficiali di Llama da GitHub.
Converti i pesi.
- Imposta ${IN_CKPOINT} sulla posizione che contiene i pesi di Llama
- Imposta ${OUT_CKPOINT} su un punto di controllo di scrittura della posizione
export input_ckpt_dir=${IN_CKPOINT} export output_ckpt_dir=${OUT_CKPOINT} export quantize=True python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
Esegui il motore PyTorch di JetStream localmente
Per eseguire il motore PyTorch di JetStream in locale, imposta il percorso del tokenizzatore:
export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama
Esegui il motore PyTorch JetStream con Llama 7B
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Esegui il motore PyTorch JetStream con Llama 13b
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Esegui il server JetStream
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
NOTA: il parametro --platform=tpu=
deve specificare il numero di dispositivi TPU
(4 per v4-8
e 8 per v5lite-8
). Ad esempio, --platform=tpu=8
.
Dopo aver eseguito run_server.py
, il motore PyTorch di JetStream è pronto per ricevere chiamate gRPC.
Eseguire benchmark
Vai alla cartella deps/JetStream
scaricata quando hai eseguito install_everything.sh
.
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
Per ulteriori informazioni, consulta deps/JetStream/benchmarks/README.md
.
Errori tipici
Se ricevi un errore Unexpected keyword argument 'device'
, prova a procedere nel seguente modo:
- Disinstalla le dipendenze di
jax
ejaxlib
- Reinstalla utilizzando
source install_everything.sh
Se ricevi un errore Out of memory
, prova a procedere nel seguente modo:
- Utilizza dimensioni dei batch più piccole
- Utilizza la quantizzazione
Esegui la pulizia
Per evitare che al tuo account Google Cloud vengano addebitati costi relativi alle risorse utilizzate in questo tutorial, elimina il progetto che contiene le risorse oppure mantieni il progetto ed elimina le singole risorse.
Ripulire i repository GitHub
# Clean up the JetStream repository rm -rf JetStream # Clean up the xla repository rm -rf xla
Ripulisci l'ambiente virtuale Python
rm -rf .env
Elimina le risorse TPU
Per ulteriori informazioni, consulta la sezione Eliminare le risorse TPU.