Inferência do PyTorch do JetStream na VM do Cloud TPU v5e


O JetStream é um mecanismo otimizado para capacidade de processamento e memória para inferência de modelos de linguagem grandes (LLMs) em dispositivos XLA (TPUs).

Antes de começar

Siga as etapas em Configurar o ambiente do Cloud TPU para criar um projeto do Google Cloud, ativar a API TPU, instalar o CLI TPU e solicitar a cota do TPU.

Siga as etapas em Criar uma Cloud TPU usando a API CreateNode para criar uma VM de TPU definindo --accelerator-type como v5litepod-8.

Clonar o repositório JetStream e instalar dependências

  1. Conectar-se à VM do TPU usando SSH

    • Defina ${TPU_NAME} como o nome da TPU.
    • Defina ${PROJECT} como seu projeto do Google Cloud
    • Defina ${ZONE} como a zona do Google Cloud em que você vai criar as TPUs.
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Clonar o repositório JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Opcional) Crie e ative um ambiente virtual do Python usando venv ou conda.

  3. Executar o script de instalação

       cd jetstream-pytorch
       source install_everything.sh
    

Fazer o download e converter pesos

  1. Faça o download dos pesos oficiais do Llama no GitHub.

  2. Converta os pesos.

    • Defina ${IN_CKPOINT} como o local que contém os pesos da lhama
    • Definir ${OUT_CKPOINT} como um ponto de verificação de gravação de local
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Executar o mecanismo PyTorch do JetStream localmente

Para executar o mecanismo PyTorch do JetStream localmente, defina o caminho do tokenizer:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Executar o mecanismo PyTorch do JetStream com o Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Executar o mecanismo PyTorch do JetStream com o Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Executar o servidor do JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

OBSERVAÇÃO: o parâmetro --platform=tpu= precisa especificar o número de dispositivos TPU, que é 4 para v4-8 e 8 para v5lite-8. Por exemplo, --platform=tpu=8.

Depois de executar run_server.py, o mecanismo PyTorch do JetStream fica pronto para receber chamadas gRPC.

Executar comparativos

Mude para a pasta deps/JetStream que foi transferida por download quando você executou install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Para mais informações, consulte deps/JetStream/benchmarks/README.md.

Erros típicos

Se você receber um erro Unexpected keyword argument 'device', tente o seguinte:

  • Desinstalar as dependências jax e jaxlib
  • Reinstalar usando source install_everything.sh

Se você receber um erro Out of memory, tente o seguinte:

  • Usar tamanhos de lote menores
  • Usar a quantização

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.

  1. Limpar os repositórios do GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Limpar o ambiente virtual do Python

    rm -rf .env
    
  3. Excluir recursos da TPU

    Para mais informações, consulte Excluir seus recursos de TPU.