Inférence JetStream MaxText sur des VM TPU v6e
Ce tutoriel explique comment utiliser JetStream pour mettre en service des modèles MaxText sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU). Dans ce tutoriel, vous allez exécuter le benchmark d'inférence pour le modèle Llama2-7B.
Avant de commencer
Préparez-vous à provisionner un TPU v6e avec quatre puces :
Suivez le guide Configurer l'environnement Cloud TPU pour configurer un projet Google Cloud , configurer la Google Cloud CLI, activer l'API Cloud TPU et vous assurer d'avoir accès à Cloud TPU.
Authentifiez-vous avec Google Cloud et configurez le projet et la zone par défaut pour Google Cloud CLI.
gcloud auth login gcloud config set project PROJECT_ID gcloud config set compute/zone ZONE
Sécuriser la capacité
Lorsque vous êtes prêt à sécuriser la capacité des TPU, consultez Quotas Cloud TPU pour en savoir plus sur les quotas Cloud TPU. Si vous avez d'autres questions sur la sécurisation de la capacité, contactez votre équipe commerciale ou l'équipe chargée du compte Cloud TPU.
Provisionner l'environnement Cloud TPU
Vous pouvez provisionner des VM TPU avec GKE, avec GKE et XPK, ou en tant que ressources en file d'attente.
Prérequis
- Vérifiez que votre projet dispose d'un quota
TPUS_PER_TPU_FAMILYsuffisant, qui spécifie le nombre maximal de puces auxquelles vous pouvez accéder dans votre projetGoogle Cloud . - Vérifiez que votre projet dispose d'un quota TPU suffisant pour :
- Quota de VM TPU
- Quota d'adresses IP
- Quota Hyperdisk Balanced
- Autorisations liées au projet utilisateur
- Si vous utilisez GKE avec XPK, consultez Autorisations de la console Cloud sur le compte utilisateur ou de service pour connaître les autorisations nécessaires à l'exécution de XPK.
Créer des variables d'environnement
Dans Cloud Shell, créez les variables d'environnement suivantes :
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east5-b export ACCELERATOR_TYPE=v6e-4 export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id
Descriptions des variables d'environnement
| Variable | Description |
|---|---|
PROJECT_ID |
ID de votre projet Google Cloud . Utilisez un projet existant ou créez-en un. |
TPU_NAME |
Nom du TPU. |
ZONE |
Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez Régions et zones TPU. |
ACCELERATOR_TYPE |
Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez Versions de TPU. |
RUNTIME_VERSION |
Version logicielle de Cloud TPU. |
SERVICE_ACCOUNT |
Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud .
Par exemple :
|
QUEUED_RESOURCE_ID |
ID de texte attribué par l'utilisateur de la demande de ressource en file d'attente. |
Provisionner un TPU v6e
Exécutez la commande suivante pour provisionner un TPU v6e :
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id=${TPU_NAME} \ --project=${PROJECT_ID} \ --zone=${ZONE} \ --accelerator-type=${ACCELERATOR_TYPE} \ --runtime-version=${RUNTIME_VERSION} \ --service-account=${SERVICE_ACCOUNT}
Utilisez les commandes list ou describe pour interroger l'état de votre ressource en file d'attente.
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} --zone ${ZONE}
Pour en savoir plus sur l'état des demandes de ressources en file d'attente, consultez Gérer des ressources en file d'attente.
Se connecter au TPU à l'aide de SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME}
Une fois que vous êtes connecté au TPU, vous pouvez exécuter le benchmark d'inférence.
Configurer votre environnement de VM TPU
Créez un répertoire pour exécuter le benchmark d'inférence :
export MAIN_DIR=your-main-directory mkdir -p ${MAIN_DIR}
Configurez un environnement virtuel Python :
cd ${MAIN_DIR} sudo apt update sudo apt install python3.10 python3.10-venv python3.10 -m venv venv source venv/bin/activate
Installez Git Large File Storage (LFS) (pour les données OpenOrca) :
sudo apt-get install git-lfs git lfs install
Clonez et installez JetStream :
cd $MAIN_DIR git clone https://github.com/google/JetStream.git cd JetStream git checkout main pip install -e . cd benchmarks pip install -r requirements.in
Configurez MaxText :
cd $MAIN_DIR git clone https://github.com/google/maxtext.git cd maxtext git checkout main bash setup.sh pip install torch --index-url https://download.pytorch.org/whl/cpu
Demandez l'accès aux modèles Llama pour obtenir une clé de téléchargement de Meta pour le modèle Llama 2.
Clonez le dépôt Llama :
cd $MAIN_DIR git clone https://github.com/meta-llama/llama cd llama
Exécutez
bash download.sh. Lorsque vous y êtes invité, fournissez votre clé de téléchargement. Ce script crée un répertoirellama-2-7bdans votre répertoirellama.bash download.shCréez des buckets de stockage :
export CHKPT_BUCKET=gs://your-checkpoint-bucket export BASE_OUTPUT_DIRECTORY=gs://your-output-dir export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data gcloud storage buckets create ${CHKPT_BUCKET} gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH} gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED} gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}
Effectuer la conversion des points de contrôle
Effectuer la conversion en points de contrôle analysés :
cd $MAIN_DIR/maxtext python3 -m MaxText.llama_or_mistral_ckpt \ --base-model-path $MAIN_DIR/llama/llama-2-7b \ --model-size llama2-7b \ --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
Convertir en points de contrôle non analysés :
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint python3 -m MaxText.generate_param_only_checkpoint \ MaxText/configs/base.yml \ base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \ model_name='llama2-7b' \ force_unroll=true
Effectuer une inférence
Exécutez un test de validation :
export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items python3 -m MaxText.decode \ MaxText/configs/base.yml \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ run_name=runner_decode_unscanned_${idx} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ per_device_batch_size=1 \ model_name='llama2-7b' \ ici_autoregressive_parallelism=4 \ max_prefill_predict_length=4 \ max_target_length=16 \ prompt="I love to" \ attention=dot_product \ scan_layers=false
Exécutez le serveur dans votre terminal actuel :
export TOKENIZER_PATH=assets/tokenizer.llama2 export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} export MAX_PREFILL_PREDICT_LENGTH=1024 export MAX_TARGET_LENGTH=2048 export MODEL_NAME=llama2-7b export ICI_FSDP_PARALLELISM=1 export ICI_AUTOREGRESSIVE_PARALLELISM=1 export ICI_TENSOR_PARALLELISM=-1 export SCAN_LAYERS=false export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 cd $MAIN_DIR/maxtext python3 -m MaxText.maxengine_server \ MaxText/configs/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${LOAD_PARAMETERS_PATH} \ max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ max_target_length=${MAX_TARGET_LENGTH} \ model_name=${MODEL_NAME} \ ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ scan_layers=${SCAN_LAYERS} \ weight_dtype=${WEIGHT_DTYPE} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
Ouvrez une nouvelle fenêtre de terminal, connectez-vous au TPU et basculez sur le même environnement virtuel que celui que vous avez utilisé dans la première fenêtre de terminal :
source venv/bin/activateExécutez les commandes suivantes pour exécuter le benchmark JetStream.
export MAIN_DIR=your-main-directory cd $MAIN_DIR python JetStream/benchmarks/benchmark_serving.py \ --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \ --warmup-mode sampled \ --save-result \ --save-request-outputs \ --request-outputs-file-path outputs.json \ --num-prompts 1000 \ --max-output-length 1024 \ --dataset openorca \ --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl
Résultats
Le résultat suivant a été généré lors de l'exécution du benchmark avec v6e-8. Les résultats varient en fonction du matériel, des logiciels, du modèle et de la mise en réseau.
Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms
Effectuer un nettoyage
Déconnectez-vous du TPU :
$ (vm) exit
Supprimez le TPU :
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --force \ --async
Supprimez les buckets et leur contenu :
export CHKPT_BUCKET=gs://your-checkpoint-bucket export BASE_OUTPUT_DIRECTORY=gs://your-output-dir export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data gcloud storage rm -r ${CHKPT_BUCKET} gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY} gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH} gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}