Inferencia de JetStream MaxText en la VM de Cloud TPU v5e


JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos grandes de lenguaje (LLM) en dispositivos XLA (TPU).

Antes de comenzar

Sigue los pasos que se indican en Administra los recursos de TPU para crear una configuración de VM de TPU --accelerator-type en v5litepod-8 y conectarte a la VM de TPU.

Configura JetStream y MaxText

  1. Descarga JetStream y el repositorio de MaxText en GitHub

       git clone -b jetstream-v0.2.1 https://github.com/google/maxtext.git
       git clone -b v0.2.1 https://github.com/google/JetStream.git
    
  2. Configura MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Convierte puntos de control del modelo

Puedes ejecutar el servidor JetStream MaxText con los modelos Gemma o Llama2. En esta sección, se describe cómo ejecutar el servidor MaxText de JetStream con varios tamaños de estos modelos.

Usar un punto de control del modelo Gemma

  1. Descarga un punto de control de Gemma desde Kaggle.
  2. Copia el punto de control en tu bucket de Cloud Storage

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    Para ver un ejemplo que incluya valores para ${YOUR_CKPT_PATH} y ${CHKPT_BUCKET}, consulta la secuencia de comandos de conversión.

  3. Convierte el punto de control de Gemma en un punto de control sin analizar compatible con MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Usar un punto de control del modelo de Llama2

  1. Descarga un punto de control de Llama2 de la comunidad de código abierto o usa uno que hayas generado.

  2. Copia los puntos de control en tu bucket de Cloud Storage.

       gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}
    

    Si deseas ver un ejemplo que incluya valores para ${YOUR_CKPT_PATH} y ${CHKPT_BUCKET}, consulta la secuencia de comandos de conversión.

  3. Convierte el punto de control de Llama2 en un punto de control compatible sin analizar con MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Ejecuta el servidor de JetStream MaxText

En esta sección, se describe cómo ejecutar el servidor de MaxText mediante el uso de un punto de control compatible con MaxText.

Configura las variables de entorno para el servidor MaxText

Exporta las siguientes variables de entorno según el modelo que usas. Usa el valor de UNSCANNED_CKPT_PATH que aparece en el resultado de model_ckpt_conversion.sh.

Crear variables de entorno Gemma-7b para marcas del servidor

Configura las marcas del servidor de JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Crea variables de entorno de Llama2-7b para marcas del servidor

Configura las marcas del servidor de JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Crea las variables de entorno de Llama2-13b para las marcas del servidor

Configura las marcas del servidor de JetStream MaxText.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Inicia el servidor de JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descripciones de las marcas del servidor JetStream MaxText

tokenizer_path
La ruta de acceso a un tokenizador (debe coincidir con tu modelo).
load_parameters_path
Carga los parámetros (sin estados del optimizador) de un directorio específico.
per_device_batch_size
decodificación del tamaño del lote por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
Longitud máxima del autocompletado cuando se realiza la regresión automática
max_target_length
Longitud máxima de la secuencia
model_name
Nombre del modelo
ici_fsdp_parallelism
La cantidad de fragmentos para el paralelismo de FSDP
ici_autoregressive_parallelism
La cantidad de fragmentos para el paralelismo autorregresivo
ici_tensor_parallelism
La cantidad de fragmentos para el paralelismo de tensor
weight_dtype
Tipo de datos de ponderación (por ejemplo, bfloat16)
scan_layers
Marca booleana de capas de análisis (configurada en "falso" para inferencia)

Envía una solicitud de prueba al servidor MaxText de JetStream

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

El resultado será similar al siguiente ejemplo:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Cómo ejecutar comparativas con el servidor de JetStream MaxText

Para obtener los mejores resultados de comparativas, habilita la cuantización (usa puntos de control ajustados o entrenados con AQT para garantizar la precisión) para las ponderaciones y la caché de KV. Para habilitar la cuantización, configura las marcas de cuantización:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Comparativas de Gemma-7b

Para obtener comparativas de Gemma-7b, haz lo siguiente:

  1. Descarga el conjunto de datos ShareGPT.
  2. Asegúrate de usar el tokenizador de Gemma (tokenizer.gemma) cuando ejecutes Gemma 7b.
  3. Agrega la marca --warmup-first en tu primera ejecución para preparar el servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true

Comparativas de Llama2 más grande

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-first true

Limpia

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env