Inferencia de JetStream MaxText en la VM de Cloud TPU v5e


JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grande (LLM) en dispositivos XLA (TPU).

Antes de comenzar

Sigue los pasos que se indican en Administra los recursos de TPU para crea una configuración de VM de TPU --accelerator-type en v5litepod-8 y conéctate a la VM de TPU.

Configura JetStream y MaxText

  1. Descarga JetStream y el repositorio de MaxText de GitHub

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Cómo configurar MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Convierte los puntos de control del modelo

Puedes ejecutar el servidor JetStream MaxText con modelos Gemma o Llama2. En esta sección, se describe cómo ejecutar el servidor JetStream MaxText con varios tamaños de estos modelos.

Usa un punto de control de un modelo de Gemma

  1. Descarga un punto de control de Gemma desde Kaggle.
  2. Copia el punto de control en tu bucket de Cloud Storage

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para ver un ejemplo que incluya valores para ${YOUR_CKPT_PATH} y ${CHKPT_BUCKET}, consulta la secuencia de comandos de conversiones.

  3. Convertir el punto de control de Gemma en un punto de control no analizado compatible con MaxText

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Usa un punto de control de modelo de Llama2

  1. Descargar un punto de control de Llama2 de la comunidad de código abierto o usar uno que hayas generado.

  2. Copia los puntos de control en tu bucket de Cloud Storage.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para ver un ejemplo que incluye valores para ${YOUR_CKPT_PATH} y ${CHKPT_BUCKET}, consulta la secuencia de comandos de conversión.

  3. Convierte el punto de control de Llama2 en un punto de control no analizado compatible con MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Ejecuta el servidor MaxText de JetStream

En esta sección, se describe cómo ejecutar el servidor de MaxText mediante un servidor compatible con MaxText punto de control.

Configura las variables de entorno para el servidor MaxText

Exporta las siguientes variables de entorno según el modelo que usas. Usa el valor de UNSCANNED_CKPT_PATH del resultado de model_ckpt_conversion.sh.

Crea variables de entorno de Gemma-7b para las marcas del servidor

Configura las marcas del servidor MaxText de JetStream.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Crea variables de entorno de Llama2-7b para las marcas del servidor

Configura las marcas del servidor de JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Crea variables de entorno Llama2-13b para las marcas del servidor

Configura las marcas del servidor MaxText de JetStream.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Inicia el servidor de JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descripciones de las marcas del servidor de MaxText de JetStream

tokenizer_path
La ruta a un tokenizador (debe coincidir con tu modelo).
load_parameters_path
Carga los parámetros (sin estados del optimizador) de un directorio específico.
per_device_batch_size
decodificación del tamaño del lote por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
Longitud máxima del completado previo cuando se realiza una autorregresión
max_target_length
Longitud máxima de la secuencia
model_name
Nombre del modelo
ici_fsdp_parallelism
La cantidad de fragmentos para el paralelismo de FSDP
ici_autoregressive_parallelism
La cantidad de fragmentos para el paralelismo autorregresivo
ici_tensor_parallelism
La cantidad de fragmentos para el paralelismo de tensores
weight_dtype
Tipo de datos de peso (por ejemplo, bfloat16)
scan_layers
Marca booleana de capas de escaneo (configurada en "false" para la inferencia)

Envía una solicitud de prueba al servidor de JetStream MaxText

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

El resultado será similar al siguiente ejemplo:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Ejecuta comparativas con el servidor JetStream MaxText

Para obtener los mejores resultados de comparativas, habilita la cuantificación (usa puntos de control entrenados o ajustados con precisión de AQT para garantizar la precisión) para las ponderaciones y la caché de KV. Para habilitar quantization, establece las marcas de cuantización:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Comparativas de Gemma-7b

Para comparar Gemma-7b, haz lo siguiente:

  1. Descarga el conjunto de datos de ShareGPT.
  2. Asegúrate de usar el tokenizador de Gemma (tokenizer.gemma) cuando ejecutes Gemma 7b.
  3. Agrega la marca --warmup-first para tu primera ejecución para activar el servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Cómo obtener comparativas de Llama2 más grandes

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Limpia

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env