Inferencia de JetStream PyTorch en la VM de Cloud TPU v5e


JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos grandes de lenguaje (LLM) en dispositivos XLA (TPU).

Antes de comenzar

Sigue los pasos que se indican en Configura el entorno de Cloud TPU para crear un proyecto de Google Cloud, activar la API de TPU, instalar la CLI de TPU y solicitar una cuota de TPU.

Sigue los pasos que se indican en Crea una Cloud TPU con la API de CreateNode para crear una configuración de VM de TPU --accelerator-type como v5litepod-8.

Clona el repositorio de JetStream y, luego, instala las dependencias

  1. Conéctate a la VM de TPU con SSH

    • Configura ${TPU_NAME} con el nombre de tu TPU.
    • Configura ${PROJECT} para tu proyecto de Google Cloud.
    • Configura ${ZONE} en la zona de Google Cloud en la que se crearán las TPU
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Clona el repositorio de JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Opcional) Crea un entorno virtual de Python con venv o conda y actívalo.

  3. Ejecuta la secuencia de comandos de instalación

       cd jetstream-pytorch
       source install_everything.sh
    

Descarga y convierte pesos

  1. Descarga los pesos oficiales de Llama desde GitHub.

  2. Convierte los pesos.

    • Establece ${IN_CKPOINT} en la ubicación que contiene los pesos de las llamas.
    • Establecer ${OUT_CKPOINT} en puntos de control de escritura de ubicación
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Ejecuta el motor JetStream PyTorch de forma local

Para ejecutar el motor JetStream PyTorch de forma local, establece la ruta de acceso del tokenizador:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Ejecuta el motor JetStream PyTorch con Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Ejecuta el motor JetStream PyTorch con Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Ejecuta el servidor de JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

NOTA: El parámetro --platform=tpu= debe especificar la cantidad de dispositivos de TPU (que es 4 para v4-8 y 8 para v5lite-8). Por ejemplo, --platform=tpu=8.

Después de ejecutar run_server.py, el motor JetStream PyTorch está listo para recibir llamadas de gRPC.

Ejecuta comparativas

Cambia a la carpeta deps/JetStream que se descargó cuando ejecutaste install_everything.sh.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Para obtener más información, consulta deps/JetStream/benchmarks/README.md.

Errores típicos

Si recibes un error Unexpected keyword argument 'device', prueba lo siguiente:

  • Desinstala las dependencias jax y jaxlib
  • Reinstalar con source install_everything.sh

Si recibes un error Out of memory, prueba lo siguiente:

  • Usar un tamaño de lote más pequeño
  • Usa la cuantización

Limpia

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

  1. Limpia los repositorios de GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Limpia el entorno virtual de Python

    rm -rf .env
    
  3. Borra tus recursos de TPU

    Para obtener más información, consulta Borra tus recursos de TPU.