Antes de começar
Siga as etapas em Configurar o ambiente do Cloud TPU para criar um projeto do Google Cloud, ativar a API TPU, instalar a CLI da TPU e solicitar a cota de TPU.
Siga as etapas em Criar uma Cloud TPU usando a API CreateNode para criar uma configuração de VM da TPU --accelerator-type
como v5litepod-8
.
Clonar o repositório do JetStream e instalar as dependências
Conectar-se à VM da TPU usando SSH
- Defina ${TPU_NAME} como o nome da TPU.
- Defina ${PROJECT} como seu projeto do Google Cloud
- Defina ${ZONE} como a zona do Google Cloud em que as TPUs serão criadas
gcloud compute config-ssh gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
Clonar o repositório do JetStream
git clone https://github.com/google/jetstream-pytorch.git
Opcional: crie e ative um ambiente Python virtual usando
venv
ouconda
.Execute o script de instalação
cd jetstream-pytorch source install_everything.sh
Fazer o download e converter pesos
Faça o download dos pesos oficiais de Llama no GitHub.
Converta os pesos.
- Defina ${IN_CKPOINT} como o local que contém os pesos da Llama
- Defina ${OUT_CKPOINT} como um local de gravação de pontos de verificação
export input_ckpt_dir=${IN_CKPOINT} export output_ckpt_dir=${OUT_CKPOINT} export quantize=True python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
Executar o mecanismo JetStream PyTorch localmente
Para executar o mecanismo JetStream PyTorch localmente, defina o caminho do tokenizador:
export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama
Executar o mecanismo JetStream PyTorch com o Llama 7B
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Executar o mecanismo JetStream PyTorch com o Llama 13b
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Executar o servidor do JetStream
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
OBSERVAÇÃO: o parâmetro --platform=tpu=
precisa especificar o número de dispositivos TPU (que é 4 para v4-8
e 8 para v5lite-8
). Por exemplo, --platform=tpu=8
.
Depois de executar run_server.py
, o mecanismo JetStream PyTorch está pronto para receber chamadas gRPC.
Executar comparativos
Acesse a pasta deps/JetStream
que foi transferida por download quando você executou
install_everything.sh
.
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
Para mais informações, consulte deps/JetStream/benchmarks/README.md
.
Erros típicos
Se você receber um erro Unexpected keyword argument 'device'
, tente o seguinte:
- Desinstalar as dependências
jax
ejaxlib
- Reinstalar usando
source install_everything.sh
Se você receber um erro Out of memory
, tente o seguinte:
- Usar um tamanho de lote menor
- Usar a quantização
Limpar
Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.
Limpar os repositórios do GitHub
# Clean up the JetStream repository rm -rf JetStream # Clean up the xla repository rm -rf xla
Limpar o ambiente virtual do Python
rm -rf .env
Excluir seus recursos de TPU
Para mais informações, consulte Excluir recursos de TPU.