Inferencia de MaxText de JetStream en VMs de TPU v6e
En este tutorial se muestra cómo usar JetStream para servir modelos MaxText en TPU v6e. JetStream es un motor optimizado de rendimiento y memoria para la inferencia de modelos de lenguaje extensos (LLMs) en dispositivos XLA (TPUs). En este tutorial, ejecutarás la prueba comparativa de inferencia del modelo Llama2-7B.
Antes de empezar
Prepara el aprovisionamiento de una TPU v6e con 4 chips:
Sigue la guía Configurar el entorno de TPU de Cloud para configurar un proyecto de Google Cloud , configurar la CLI de Google Cloud, habilitar la API Cloud TPU y asegurarte de que tienes acceso para usar las TPU de Cloud.
Autentícate con Google Cloud y configura el proyecto y la zona predeterminados de la CLI de Google Cloud.
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
Capacidad segura
Cuando quieras proteger la capacidad de las TPUs, consulta Cuotas de TPUs de Cloud para obtener más información sobre las cuotas de TPUs de Cloud. Si tienes más preguntas sobre cómo proteger la capacidad, ponte en contacto con el equipo de Ventas o de Cuentas de Cloud TPU.
Aprovisionar el entorno de TPU de Cloud
Puedes aprovisionar VMs de TPU con GKE, con GKE y XPK o como recursos en cola.
Requisitos previos
- Comprueba que tu proyecto tenga suficiente cuota de
TPUS_PER_TPU_FAMILY
, que especifica el número máximo de chips a los que puedes acceder en tuGoogle Cloud proyecto. - Comprueba que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
- Cuota de máquinas virtuales de TPU
- Cuota de direcciones IP
- Cuota de Hyperdisk Balanced
- Permisos de proyecto de usuario
- Si usas GKE con XPK, consulta los permisos de la consola de Cloud en la cuenta de usuario o de servicio para saber qué permisos necesitas para ejecutar XPK.
Crear variables de entorno
En Cloud Shell, crea las siguientes variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-b export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descripciones de las variables de entorno
Variable | Descripción |
---|---|
PROJECT_ID |
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. |
TPU_NAME |
El nombre de la TPU. |
ZONE |
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. |
ACCELERATOR_TYPE |
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. |
RUNTIME_VERSION |
La versión de software de la TPU de Cloud. |
SERVICE_ACCOUNT |
La dirección de correo de tu cuenta de servicio. Para encontrarlo, ve a la
página Cuentas de servicio de la
consola Google Cloud .
Por ejemplo:
|
QUEUED_RESOURCE_ID |
ID de texto asignado por el usuario de la solicitud de recurso en cola. |
Aprovisionar una TPU v6e
Usa el siguiente comando para aprovisionar una TPU v6e:
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Usa los comandos list
o describe
para consultar el estado del recurso en cola.
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
Para obtener más información sobre los estados de las solicitudes de recursos en cola, consulta Gestionar recursos en cola.
Conectarse a la TPU mediante SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME}
Una vez que te hayas conectado a la TPU, podrás ejecutar la prueba comparativa de inferencia.
Configurar el entorno de la VM de TPU
Crea un directorio para ejecutar la prueba comparativa de inferencia:
export MAIN_DIR=your-main-directory mkdir -p ${MAIN_DIR}
Configura un entorno virtual de Python:
cd ${MAIN_DIR} sudo apt update sudo apt install python3.10 python3.10-venv python3.10 -m venv venv source venv/bin/activate
Instala Git Large File Storage (LFS) (para los datos de OpenOrca):
sudo apt-get install git-lfs git lfs install
Clona e instala JetStream:
cd $MAIN_DIR git clone https://github.com/google/JetStream.git cd JetStream git checkout main pip install -e . cd benchmarks pip install -r requirements.in
Configura MaxText:
cd $MAIN_DIR git clone https://github.com/google/maxtext.git cd maxtext git checkout main bash setup.sh pip install torch --index-url https://download.pytorch.org/whl/cpu
Solicita acceso a los modelos Llama para obtener una clave de descarga de Meta para el modelo Llama 2.
Clona el repositorio de Llama:
cd $MAIN_DIR git clone https://github.com/meta-llama/llama cd llama
Ejecuta
bash download.sh
. Cuando se te solicite, proporciona tu clave de descarga. Esta secuencia de comandos crea un directoriollama-2-7b
dentro del directoriollama
.bash download.sh
Crea segmentos de almacenamiento:
export CHKPT_BUCKET=gs://your-checkpoint-bucket export BASE_OUTPUT_DIRECTORY=gs://your-output-dir export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data gcloud storage buckets create ${CHKPT_BUCKET} gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH} gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED} gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}
Realizar la conversión de puntos de control
Realiza la conversión a los puntos de control escaneados:
cd $MAIN_DIR/maxtext python3 -m MaxText.llama_or_mistral_ckpt \ --base-model-path $MAIN_DIR/llama/llama-2-7b \ --model-size llama2-7b \ --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
Para convertir los puntos de control en puntos de control no escaneados, sigue estos pasos:
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint python3 -m MaxText.generate_param_only_checkpoint \ MaxText/configs/base.yml \ base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \ model_name='llama2-7b' \ force_unroll=true
Realizar inferencias
Ejecuta una prueba de validación:
export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items python3 -m MaxText.decode \ MaxText/configs/base.yml \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ run_name=runner_decode_unscanned_${idx} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ per_device_batch_size=1 \ model_name='llama2-7b' \ ici_autoregressive_parallelism=4 \ max_prefill_predict_length=4 \ max_target_length=16 \ prompt="I love to" \ attention=dot_product \ scan_layers=false
Ejecuta el servidor en el terminal actual:
export TOKENIZER_PATH=assets/tokenizer.llama2 export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-7b export ICI_FSDP_PARALLELISM=1 export ICI_AUTOREGRESSIVE_PARALLELISM=1 export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 cd $MAIN_DIR/maxtext python3 -m MaxText.maxengine_server \ MaxText/configs/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${LOAD_PARAMETERS_PATH} \ max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ max_target_length=${MAX_TARGET_LENGTH} \ model_name=${MODEL_NAME} \ ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ scan_layers=${SCAN_LAYERS} \ weight_dtype=${WEIGHT_DTYPE} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Abre una nueva ventana de terminal, conéctate a la TPU y cambia al mismo entorno virtual que usaste en la primera ventana de terminal:
source venv/bin/activate
Ejecuta los siguientes comandos para ejecutar la prueba comparativa de JetStream.
export MAIN_DIR=your-main-directory cd $MAIN_DIR python JetStream/benchmarks/benchmark_serving.py \ --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \ --warmup-mode sampled \ --save-result \ --save-request-outputs \ --request-outputs-file-path outputs.json \ --num-prompts 1000 \ --max-output-length 1024 \ --dataset openorca \ --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl
Resultados
Se generó el siguiente resultado al ejecutar la prueba de rendimiento con v6e-8. Los resultados variarán en función del hardware, el software, el modelo y la red.
Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms
Limpieza
Desconecta la TPU:
$ (vm) exit
Elimina la TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --force \ --async
Elimina los segmentos y su contenido:
export CHKPT_BUCKET=gs://your-checkpoint-bucket export BASE_OUTPUT_DIRECTORY=gs://your-output-dir export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data gcloud storage rm -r ${CHKPT_BUCKET} gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY} gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH} gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}