Inférence JetStream MaxText sur des VM TPU v6e

Ce tutoriel explique comment utiliser JetStream pour mettre en service des modèles MaxText sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU). Dans ce tutoriel, vous allez exécuter le benchmark d'inférence pour le modèle Llama2-7B.

Avant de commencer

Préparez-vous à provisionner un TPU v6e avec quatre puces :

  1. Suivez le guide Configurer l'environnement Cloud TPU pour configurer un projet Google Cloud , configurer la Google Cloud CLI, activer l'API Cloud TPU et vous assurer d'avoir accès à Cloud TPU.

  2. Authentifiez-vous avec Google Cloud et configurez le projet et la zone par défaut pour Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Sécuriser la capacité

Lorsque vous êtes prêt à sécuriser la capacité des TPU, consultez Quotas Cloud TPU pour en savoir plus sur les quotas Cloud TPU. Si vous avez d'autres questions sur la sécurisation de la capacité, contactez votre équipe commerciale ou l'équipe chargée du compte Cloud TPU.

Provisionner l'environnement Cloud TPU

Vous pouvez provisionner des VM TPU avec GKE, avec GKE et XPK, ou en tant que ressources en file d'attente.

Prérequis

  • Vérifiez que votre projet dispose d'un quota TPUS_PER_TPU_FAMILY suffisant, qui spécifie le nombre maximal de puces auxquelles vous pouvez accéder dans votre projetGoogle Cloud .
  • Vérifiez que votre projet dispose d'un quota TPU suffisant pour :
    • Quota de VM TPU
    • Quota d'adresses IP
    • Quota Hyperdisk Balanced
  • Autorisations liées au projet utilisateur

Créer des variables d'environnement

Dans Cloud Shell, créez les variables d'environnement suivantes :

export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id

Descriptions des variables d'environnement

Variable Description
PROJECT_ID ID de votre projet Google Cloud . Utilisez un projet existant ou créez-en un.
TPU_NAME Nom du TPU.
ZONE Zone dans laquelle créer la VM TPU. Pour en savoir plus sur les zones compatibles, consultez Régions et zones TPU.
ACCELERATOR_TYPE Le type d'accélérateur spécifie la version et la taille du Cloud TPU que vous souhaitez créer. Pour en savoir plus sur les types d'accélérateurs compatibles avec chaque version de TPU, consultez Versions de TPU.
RUNTIME_VERSION Version logicielle de Cloud TPU.
SERVICE_ACCOUNT Adresse e-mail de votre compte de service. Pour le trouver, accédez à la page Comptes de service dans la console Google Cloud .

Par exemple : tpu-service-account@PROJECT_ID.iam.

QUEUED_RESOURCE_ID ID de texte attribué par l'utilisateur de la demande de ressource en file d'attente.

Provisionner un TPU v6e

Exécutez la commande suivante pour provisionner un TPU v6e :

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT}

Utilisez les commandes list ou describe pour interroger l'état de votre ressource en file d'attente.

gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
    --project ${PROJECT_ID} --zone ${ZONE}

Pour en savoir plus sur l'état des demandes de ressources en file d'attente, consultez Gérer des ressources en file d'attente.

Se connecter au TPU à l'aide de SSH

   gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Une fois que vous êtes connecté au TPU, vous pouvez exécuter le benchmark d'inférence.

Configurer votre environnement de VM TPU

  1. Créez un répertoire pour exécuter le benchmark d'inférence :

    export MAIN_DIR=your-main-directory
    mkdir -p ${MAIN_DIR}
  2. Configurez un environnement virtuel Python :

    cd ${MAIN_DIR}
    sudo apt update
    sudo apt install python3.10 python3.10-venv
    python3.10 -m venv venv
    source venv/bin/activate
  3. Installez Git Large File Storage (LFS) (pour les données OpenOrca) :

    sudo apt-get install git-lfs
    git lfs install
  4. Clonez et installez JetStream :

    cd $MAIN_DIR
    git clone https://github.com/google/JetStream.git
    cd JetStream
    git checkout main
    pip install -e .
    cd benchmarks
    pip install -r requirements.in
  5. Configurez MaxText :

    cd $MAIN_DIR
    git clone https://github.com/google/maxtext.git
    cd maxtext
    git checkout main
    bash setup.sh
    pip install torch --index-url https://download.pytorch.org/whl/cpu
  6. Demandez l'accès aux modèles Llama pour obtenir une clé de téléchargement de Meta pour le modèle Llama 2.

  7. Clonez le dépôt Llama :

    cd $MAIN_DIR
    git clone https://github.com/meta-llama/llama
    cd llama
  8. Exécutez bash download.sh. Lorsque vous y êtes invité, fournissez votre clé de téléchargement. Ce script crée un répertoire llama-2-7b dans votre répertoire llama.

    bash download.sh
  9. Créez des buckets de stockage :

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage buckets create ${CHKPT_BUCKET}
    gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
    gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
    gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}

Effectuer la conversion des points de contrôle

  1. Effectuer la conversion en points de contrôle analysés :

    cd $MAIN_DIR/maxtext
    python3 -m MaxText.llama_or_mistral_ckpt \
        --base-model-path $MAIN_DIR/llama/llama-2-7b \
        --model-size llama2-7b \
        --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
  2. Convertir en points de contrôle non analysés :

    export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
    export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
    python3 -m MaxText.generate_param_only_checkpoint \
        MaxText/configs/base.yml \
        base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
        load_parameters_path=${CONVERTED_CHECKPOINT} \
        run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
        model_name='llama2-7b' \
        force_unroll=true

Effectuer une inférence

  1. Exécutez un test de validation :

    export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
    python3 -m MaxText.decode \
        MaxText/configs/base.yml \
        load_parameters_path=${UNSCANNED_CKPT_PATH} \
        run_name=runner_decode_unscanned_${idx} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        per_device_batch_size=1 \
        model_name='llama2-7b' \
        ici_autoregressive_parallelism=4 \
        max_prefill_predict_length=4 \
        max_target_length=16 \
        prompt="I love to" \
        attention=dot_product \
        scan_layers=false
  2. Exécutez le serveur dans votre terminal actuel :

    export TOKENIZER_PATH=assets/tokenizer.llama2
    export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
    export MAX_PREFILL_PREDICT_LENGTH=1024
    export MAX_TARGET_LENGTH=2048
    export MODEL_NAME=llama2-7b
    export ICI_FSDP_PARALLELISM=1
    export ICI_AUTOREGRESSIVE_PARALLELISM=1
    export ICI_TENSOR_PARALLELISM=-1
    export SCAN_LAYERS=false
    export WEIGHT_DTYPE=bfloat16
    export PER_DEVICE_BATCH_SIZE=11
    
    cd $MAIN_DIR/maxtext
    python3 -m MaxText.maxengine_server \
        MaxText/configs/base.yml \
        tokenizer_path=${TOKENIZER_PATH} \
        load_parameters_path=${LOAD_PARAMETERS_PATH} \
        max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
        max_target_length=${MAX_TARGET_LENGTH} \
        model_name=${MODEL_NAME} \
        ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
        ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
        ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
        scan_layers=${SCAN_LAYERS} \
        weight_dtype=${WEIGHT_DTYPE} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
  3. Ouvrez une nouvelle fenêtre de terminal, connectez-vous au TPU et basculez sur le même environnement virtuel que celui que vous avez utilisé dans la première fenêtre de terminal :

    source venv/bin/activate
    
  4. Exécutez les commandes suivantes pour exécuter le benchmark JetStream.

    export MAIN_DIR=your-main-directory
    cd $MAIN_DIR
    
    python JetStream/benchmarks/benchmark_serving.py \
        --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
        --warmup-mode sampled \
        --save-result \
        --save-request-outputs \
        --request-outputs-file-path outputs.json \
        --num-prompts 1000 \
        --max-output-length 1024 \
        --dataset openorca \
        --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl

Résultats

Le résultat suivant a été généré lors de l'exécution du benchmark avec v6e-8. Les résultats varient en fonction du matériel, des logiciels, du modèle et de la mise en réseau.

Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms

Effectuer un nettoyage

  1. Déconnectez-vous du TPU :

    $ (vm) exit
  2. Supprimez le TPU :

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --force \
        --async
  3. Supprimez les buckets et leur contenu :

    export CHKPT_BUCKET=gs://your-checkpoint-bucket
    export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
    export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
    export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
    gcloud storage rm -r ${CHKPT_BUCKET}
    gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}
    gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}
    gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}