Antes de comenzar
Sigue los pasos que se indican en Administra recursos de TPU para crear una VM de TPU que establezca --accelerator-type
en v5litepod-8
y conéctate a la VM de TPU.
Configura JetStream y MaxText
Descarga JetStream y el repositorio de GitHub de MaxText
git clone -b jetstream-v0.2.2 https://github.com/google/maxtext.git git clone -b v0.2.2 https://github.com/google/JetStream.git
Configura MaxText
# Create a python virtual environment sudo apt install python3.10-venv python -m venv .env source .env/bin/activate # Set up MaxText cd maxtext/ bash setup.sh
Convierte los puntos de control del modelo
Puedes ejecutar el servidor JetStream MaxText con modelos de Gemma o Llama2. En esta sección, se describe cómo ejecutar el servidor JetStream MaxText con varios tamaños de estos modelos.
Usa un punto de control de un modelo de Gemma
- Descarga un punto de control de Gemma desde Kaggle.
Copia el punto de control en tu bucket de Cloud Storage
# Set YOUR_CKPT_PATH to the path to the checkpoints # Set CHKPT_BUCKET to the Cloud Storage bucket where you copied the checkpoints gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
Para ver un ejemplo que incluye valores para
${YOUR_CKPT_PATH}
y${CHKPT_BUCKET}
, consulta la secuencia de comandos de conversión.Convierte el punto de control de Gemma en un punto de control no analizado compatible con MaxText.
# For gemma-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
Usa un punto de control de modelo de Llama2
Descarga un punto de control de Llama2 de la comunidad de código abierto o usa uno que hayas generado.
Copia los puntos de control en tu bucket de Cloud Storage.
gcloud storage cp ${YOUR_CKPT_PATH} ${CHKPT_BUCKET} --recursive
Para ver un ejemplo que incluye valores para
${YOUR_CKPT_PATH}
y${CHKPT_BUCKET}
, consulta la secuencia de comandos de conversión.Convierte el punto de control de Llama2 en un punto de control no analizado compatible con MaxText.
# For llama2-7b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET} # For llama2-13b bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
Ejecuta el servidor de JetStream MaxText
En esta sección, se describe cómo ejecutar el servidor MaxText con un punto de control compatible con MaxText.
Configura las variables de entorno para el servidor MaxText
Exporta las siguientes variables de entorno según el modelo que uses.
Usa el valor de UNSCANNED_CKPT_PATH
del resultado de model_ckpt_conversion.sh
.
Crea variables de entorno de Gemma-7b para las marcas del servidor
Configura las marcas del servidor MaxText de JetStream.
export TOKENIZER_PATH=assets/tokenizer.gemma
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=gemma-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Crea variables de entorno Llama2-7b para marcas de servidor
Configura las marcas del servidor MaxText de JetStream.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-7b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=11
Crea variables de entorno Llama2-13b para las marcas del servidor
Configura las marcas del servidor MaxText de JetStream.
export TOKENIZER_PATH=assets/tokenizer.llama2
export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
export MAX_PREFILL_PREDICT_LENGTH=1024
export MAX_TARGET_LENGTH=2048
export MODEL_NAME=llama2-13b
export ICI_FSDP_PARALLELISM=1
export ICI_AUTOREGRESSIVE_PARALLELISM=-1
export ICI_TENSOR_PARALLELISM=1
export SCAN_LAYERS=false
export WEIGHT_DTYPE=bfloat16
export PER_DEVICE_BATCH_SIZE=4
Inicia el servidor de JetStream MaxText
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Descripciones de las marcas del servidor MaxText de JetStream
tokenizer_path
- La ruta de acceso a un analizador (debe coincidir con tu modelo).
load_parameters_path
- Carga los parámetros (sin estados del optimizador) de un directorio específico.
per_device_batch_size
- tamaño del lote de decodificación por dispositivo (1 chip TPU = 1 dispositivo)
max_prefill_predict_length
- Longitud máxima del relleno previo cuando se realiza la regresión automática
max_target_length
- Longitud máxima de la secuencia
model_name
- Nombre del modelo
ici_fsdp_parallelism
- La cantidad de fragmentos para el paralelismo de FSDP
ici_autoregressive_parallelism
- La cantidad de fragmentos para el paralelismo autorregresivo
ici_tensor_parallelism
- La cantidad de fragmentos para el paralelismo de tensor
weight_dtype
- Tipo de datos de peso (por ejemplo, bfloat16)
scan_layers
- Marca booleana de capas de análisis (establecida en "false" para la inferencia)
Envía una solicitud de prueba al servidor MaxText de JetStream
cd ~
# For Gemma model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.gemma
# For Llama2 model
python JetStream/jetstream/tools/requester.py --tokenizer maxtext/assets/tokenizer.llama2
El resultado será similar al siguiente ejemplo:
Sending request to: 0.0.0.0:9000
Prompt: Today is a good day
Response: to be a fan
Ejecuta comparativas con el servidor JetStream MaxText
Para obtener los mejores resultados de comparativas, habilita la cuantificación (usa puntos de control entrenados o ajustados con precisión de AQT para garantizar la precisión) para las ponderaciones y la caché de KV. Para habilitar la cuantización, establece las marcas de cuantización:
# Enable int8 quantization for both weights and KV cache
export QUANTIZATION=int8
export QUANTIZE_KVCACHE=true
# For Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
export PER_DEVICE_BATCH_SIZE=12
cd ~/maxtext
python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${LOAD_PARAMETERS_PATH} \
max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
max_target_length=${MAX_TARGET_LENGTH} \
model_name=${MODEL_NAME} \
ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
quantization=${QUANTIZATION} \
quantize_kvcache=${QUANTIZE_KVCACHE}
Comparación de Gemma-7b
Para obtener comparativas de Gemma-7b, haz lo siguiente:
- Descarga el conjunto de datos de ShareGPT.
- Asegúrate de usar el analizador de Gemma (tokenizer.gemma) cuando ejecutes Gemma 7b.
- Agrega la marca
--warmup-first
para tu primera ejecución para activar el servidor.
# Activate the env python virtual environment
cd ~
source .env/bin/activate
# Download the dataset
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.gemma \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
Cómo realizar comparativas de Llama2 más grandes
# Run the benchmark with the downloaded dataset and the tokenizer in MaxText
# You can control the qps by setting `--request-rate`, the default value is inf.
python JetStream/benchmarks/benchmark_serving.py \
--tokenizer maxtext/assets/tokenizer.llama2 \
--num-prompts 1000 \
--dataset sharegpt \
--dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
--max-output-length 1024 \
--request-rate 5 \
--warmup-mode sampled
Limpia
Para evitar que se apliquen cargos a tu cuenta de Google Cloud por los recursos usados en este instructivo, borra el proyecto que contiene los recursos o conserva el proyecto y borra los recursos individuales.
# Delete the Cloud Storage buckets
gcloud storage buckets delete ${MODEL_BUCKET}
gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
gcloud storage buckets delete ${DATASET_PATH}
# Clean up the MaxText and JetStream repositories.
rm -rf maxtext
rm -rf JetStream
# Clean up the python virtual environment
rm -rf .env