Antes de começar
Siga as etapas em Configurar o ambiente do Cloud TPU para criar um projeto do Google Cloud, ativar a API TPU, instalar o CLI TPU e solicitar a cota do TPU.
Siga as etapas em Criar uma Cloud TPU usando a API CreateNode para
criar uma VM de TPU definindo --accelerator-type
como v5litepod-8
.
Clonar o repositório JetStream e instalar dependências
Conectar-se à VM do TPU usando SSH
- Defina ${TPU_NAME} como o nome da TPU.
- Defina ${PROJECT} como seu projeto do Google Cloud
- Defina ${ZONE} como a zona do Google Cloud em que você vai criar as TPUs.
gcloud compute config-ssh gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
Clonar o repositório do JetStream
git clone https://github.com/google/jetstream-pytorch.git
(Opcional) Crie um ambiente Python virtual usando
venv
ouconda
e ativá-la.Executar o script de instalação
cd jetstream-pytorch source install_everything.sh
Fazer o download e converter pesos
Faça o download dos pesos oficiais do Llama no GitHub.
Converta os pesos.
- Defina ${IN_CKPOINT} como o local que contém os pesos do Llama
- Definir ${OUT_CKPOINT} como um ponto de verificação de gravação de local
export input_ckpt_dir=${IN_CKPOINT} export output_ckpt_dir=${OUT_CKPOINT} export quantize=True python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
Executar o mecanismo PyTorch do JetStream localmente
Para executar o mecanismo PyTorch do JetStream localmente, defina o caminho do tokenizer:
export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama
Executar o mecanismo PyTorch do JetStream com o Llama 7B
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Executar o mecanismo PyTorch do JetStream com o Llama 13b
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Executar o servidor do JetStream
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
OBSERVAÇÃO: o parâmetro --platform=tpu=
precisa especificar o número de dispositivos TPU
(que é 4 para v4-8
e 8 para v5lite-8
). Por exemplo, --platform=tpu=8
.
Depois de executar run_server.py
, o mecanismo PyTorch do JetStream fica pronto para receber chamadas gRPC.
Executar comparativos
Mude para a pasta deps/JetStream
que foi salva durante a execução
install_everything.sh
.
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
Para mais informações, consulte deps/JetStream/benchmarks/README.md
.
Erros típicos
Se você receber um erro Unexpected keyword argument 'device'
, tente o seguinte:
- Desinstalar as dependências
jax
ejaxlib
- Reinstalar usando
source install_everything.sh
Se você receber um erro Out of memory
, tente o seguinte:
- Usar um tamanho de lote menor
- Usar a quantização
Limpar
Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.
Limpar os repositórios do GitHub
# Clean up the JetStream repository rm -rf JetStream # Clean up the xla repository rm -rf xla
Limpar o ambiente virtual do Python
rm -rf .env
Excluir recursos da TPU
Para mais informações, consulte Excluir recursos de TPU.