Inferencia de JetStream MaxText en la VM de Cloud TPU v5e


JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU).

Antes de comenzar

Sigue los pasos que se indican en Administra recursos de TPU para crear una VM de TPU que establezca --accelerator-type en v5litepod-8 y conéctate a la VM de TPU.

Configura JetStream y MaxText

  1. Descarga JetStream y el repositorio de GitHub de MaxText

       git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git
       git clone -b v0.2.2 https://github.com/google/JetStream.git
    
  2. Configura MaxText

       # Create a python virtual environment
       sudo apt install python3.10-venv
       python -m venv .env
       source .env/bin/activate
    
       # Set up MaxText
       cd maxtext/
       bash setup.sh
    

Convierte los puntos de control del modelo

Puedes ejecutar el servidor JetStream MaxText con modelos de Gemma o Llama2. En esta sección, se describe cómo ejecutar el servidor JetStream MaxText con varios tamaños de estos modelos.

Usa un punto de control de un modelo de Gemma

  1. Descarga un punto de control de Gemma desde Kaggle.
  2. Copia el punto de control en tu bucket de Cloud Storage

        # Set YOUR_CKPT_PATH to the path to the checkpoints
        # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints
        gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para ver un ejemplo que incluye valores para ${YOUR_CKPT_PATH} y ${CHKPT_BUCKET}, consulta la secuencia de comandos de conversión.

  3. Convierte el punto de control de Gemma en un punto de control no analizado compatible con MaxText.

       # For gemma-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
    

Usa un punto de control de modelo de Llama2

  1. Descarga un punto de control de Llama2 de la comunidad de código abierto o usa uno que hayas generado.

  2. Copia los puntos de control en tu bucket de Cloud Storage.

       gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
    

    Para ver un ejemplo que incluye valores para ${YOUR_CKPT_PATH} y ${CHKPT_BUCKET}, consulta la secuencia de comandos de conversión.

  3. Convierte el punto de control de Llama2 en un punto de control no analizado compatible con MaxText.

       # For llama2-7b
       bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
    
       # For llama2-13b
      bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
    

Ejecuta el servidor de JetStream MaxText

En esta sección, se describe cómo ejecutar el servidor MaxText con un punto de control compatible con MaxText.

Configura las variables de entorno para el servidor MaxText

Exporta las siguientes variables de entorno según el modelo que uses. Usa el valor de UNSCANNED_CKPT_PATH del resultado de model_ckpt_conversion.sh.

Crea variables de entorno de Gemma-7b para las marcas del servidor

Configura las marcas del servidor MaxText de JetStream.

export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Crea variables de entorno Llama2-7b para marcas de servidor

Configura las marcas del servidor MaxText de JetStream.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11

Crea variables de entorno Llama2-13b para las marcas del servidor

Configura las marcas del servidor MaxText de JetStream.

export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4

Inicia el servidor de JetStream MaxText

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}

Descripciones de las marcas del servidor MaxText de JetStream

tokenizer_path
La ruta de acceso a un analizador (debe coincidir con tu modelo).
load_parameters_path
Carga los parámetros (sin estados del optimizador) de un directorio específico.
per_device_batch_size
tamaño del lote de decodificación por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
Longitud máxima del relleno previo cuando se realiza la regresión automática
max_target_length
Longitud máxima de la secuencia
model_name
Nombre del modelo
ici_fsdp_parallelism
La cantidad de fragmentos para el paralelismo de FSDP
ici_autoregressive_parallelism
La cantidad de fragmentos para el paralelismo autorregresivo
ici_tensor_parallelism
La cantidad de fragmentos para el paralelismo de tensor
weight_dtype
Tipo de datos de peso (por ejemplo, bfloat16)
scan_layers
Marca booleana de capas de análisis (establecida en "false" para la inferencia)

Envía una solicitud de prueba al servidor MaxText de JetStream

cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2

El resultado será similar al siguiente ejemplo:

Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response:  to be a fan

Ejecuta comparativas con el servidor JetStream MaxText

Para obtener los mejores resultados de comparativas, habilita la cuantificación (usa puntos de control entrenados o ajustados con precisión de AQT para garantizar la precisión) para las ponderaciones y la caché de KV. Para habilitar la cuantización, establece las marcas de cuantización:

# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true

# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance. 
export PER_DEVICE_BATCH_SIZE=12

cd ~/maxtext
python MaxText/maxengine_server.py \
  MaxText/configs/base.yml \
  tokenizer_path=${TOKENIZER_PATH} \
  load_parameters_path=${LOAD_PARAMETERS_PATH} \
  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
  max_target_length=${MAX_TARGET_LENGTH} \
  model_name=${MODEL_NAME} \
  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
  scan_layers=${SCAN_LAYERS} \
  weight_dtype=${WEIGHT_DTYPE} \
  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
  quantization=${QUANTIZATION} \
  quantize_kvcache=${QUANTIZE_KVCACHE}

Comparación de Gemma-7b

Para obtener comparativas de Gemma-7b, haz lo siguiente:

  1. Descarga el conjunto de datos de ShareGPT.
  2. Asegúrate de usar el analizador de Gemma (tokenizer.gemma) cuando ejecutes Gemma 7b.
  3. Agrega la marca --warmup-first para tu primera ejecución para activar el servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate

# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Cómo realizar comparativas de Llama2 más grandes

# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.

python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000  \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled

Limpia

Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.

# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}

# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream

# Clean up the python virtual environment
rm -rf .env