Inferencia de JetStream PyTorch en la VM de Cloud TPU v5e


JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU).

Antes de comenzar

Sigue los pasos que se indican en Configura el entorno de Cloud TPU para crear un proyecto de Google Cloud, activar la API de TPU, instalar TPU CLI y solicitar una cuota de TPU.

Sigue los pasos que se indican en Cómo crear una Cloud TPU con la API de CreateNode para crear una VM de TPU que establezca --accelerator-type en v5litepod-8.

Clona el repositorio de JetStream y, luego, instala las dependencias

  1. Conéctate a tu VM de TPU mediante SSH

    • Establece ${TPU_NAME} en el nombre de tu TPU.
    • Configura ${PROJECT} en tu proyecto de Google Cloud
    • Establece ${ZONE} en la zona de Google Cloud en la que deseas crear tus TPU.
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Clona el repositorio de JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    Crea un entorno virtual de Python con venv o conda y actívalo (opcional).

  3. Ejecuta la secuencia de comandos de instalación

       cd jetstream-pytorch
       source install_everything.sh
    

Descarga y convierte pesos

  1. Descarga los pesos oficiales de Llama desde GitHub.

  2. Convierte los pesos.

    • Establece ${IN_CKPOINT} en la ubicación que contiene los pesos de Llama.
    • Establece ${OUT_CKPOINT} en un punto de control de escritura de ubicación.
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Ejecuta el motor de PyTorch de JetStream de forma local

Para ejecutar el motor de PyTorch de JetStream de forma local, configura la ruta de acceso del analizador:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Ejecuta el motor de PyTorch de JetStream con Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Ejecuta el motor de PyTorch de JetStream con Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Ejecuta el servidor de JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

NOTA: El parámetro --platform=tpu= debe especificar la cantidad de dispositivos TPU (que es 4 para v4-8 y 8 para v5lite-8). Por ejemplo, --platform=tpu=8.

Después de ejecutar run_server.py, el motor de PyTorch de JetStream está listo para recibir llamadas de gRPC.

Ejecuta comparativas

Cambia a la carpeta deps/JetStream que se descargó cuando ejecutaste install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Para obtener más información, consulta deps/JetStream/benchmarks/README.md.

Errores típicos

Si recibes un error Unexpected keyword argument 'device', prueba lo siguiente:

  • Desinstala las dependencias de jax y jaxlib
  • Vuelve a instalar con source install_everything.sh

Si recibes un error Out of memory, prueba lo siguiente:

  • Usa un tamaño de lote más pequeño
  • Usa la cuantización

Limpia

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

  1. Limpia los repositorios de GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Limpia el entorno virtual de Python

    rm -rf .env
    
  3. Borra tus recursos de TPU

    Para obtener más información, consulta Cómo borrar tus recursos de TPU.