Inferência do PyTorch do JetStream na VM do Cloud TPU v5e


O JetStream é um mecanismo com otimização de capacidade e memória para inferência de modelos de linguagem grandes (LLM) em dispositivos XLA (TPUs).

Antes de começar

Siga as etapas em Configurar o ambiente do Cloud TPU para criar um projeto do Google Cloud, ativar a API TPU, instalar a CLI da TPU e solicitar a cota de TPU.

Siga as etapas em Criar uma Cloud TPU usando a API CreateNode para criar uma configuração de VM da TPU --accelerator-type como v5litepod-8.

Clonar o repositório do JetStream e instalar as dependências

  1. Conectar-se à VM da TPU usando SSH

    • Defina ${TPU_NAME} como o nome da TPU.
    • Defina ${PROJECT} como seu projeto do Google Cloud
    • Defina ${ZONE} como a zona do Google Cloud em que as TPUs serão criadas
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Clonar o repositório do JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    Opcional: crie e ative um ambiente Python virtual usando venv ou conda.

  3. Execute o script de instalação

       cd jetstream-pytorch
       source install_everything.sh
    

Fazer o download e converter pesos

  1. Faça o download dos pesos oficiais de Llama no GitHub.

  2. Converta os pesos.

    • Defina ${IN_CKPOINT} como o local que contém os pesos da Llama
    • Defina ${OUT_CKPOINT} como um local de gravação de pontos de verificação
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Executar o mecanismo JetStream PyTorch localmente

Para executar o mecanismo JetStream PyTorch localmente, defina o caminho do tokenizador:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Executar o mecanismo JetStream PyTorch com o Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Executar o mecanismo JetStream PyTorch com o Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Executar o servidor do JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

OBSERVAÇÃO: o parâmetro --platform=tpu= precisa especificar o número de dispositivos TPU (que é 4 para v4-8 e 8 para v5lite-8). Por exemplo, --platform=tpu=8.

Depois de executar run_server.py, o mecanismo JetStream PyTorch está pronto para receber chamadas gRPC.

Executar comparativos

Acesse a pasta deps/JetStream que foi transferida por download quando você executou install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Para mais informações, consulte deps/JetStream/benchmarks/README.md.

Erros típicos

Se você receber um erro Unexpected keyword argument 'device', tente o seguinte:

  • Desinstalar as dependências jax e jaxlib
  • Reinstalar usando source install_everything.sh

Se você receber um erro Out of memory, tente o seguinte:

  • Usar um tamanho de lote menor
  • Usar a quantização

Limpar

Para evitar cobranças na sua conta do Google Cloud pelos recursos usados no tutorial, exclua o projeto que os contém ou mantenha o projeto e exclua os recursos individuais.

  1. Limpar os repositórios do GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Limpar o ambiente virtual do Python

    rm -rf .env
    
  3. Excluir seus recursos de TPU

    Para mais informações, consulte Excluir recursos de TPU.