Inférence JetStream PyTorch sur une VM Cloud TPU v5e


JetStream est un moteur optimisé en termes de débit et de mémoire pour les grands modèles de langage (LLM) sur les appareils XLA (TPU).

Avant de commencer

Suivez la procédure décrite dans Configurer l'environnement Cloud TPU pour créer un projet Google Cloud, activer l'API TPU, installer la CLI TPU et demander un quota TPU.

Suivez les étapes de la section Créer une ressource Cloud TPU à l'aide de l'API CreateNode pour créez une VM TPU en définissant --accelerator-type sur v5litepod-8.

Cloner le dépôt JetStream et installer les dépendances

  1. Se connecter à votre VM TPU à l'aide de SSH

    • Définissez ${TPU_NAME} sur le nom de votre TPU.
    • Définissez ${PROJECT} sur votre projet Google Cloud
    • Définissez ${ZONE} sur la zone Google Cloud dans laquelle créer vos TPU
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. Cloner le dépôt JetStream

       git clone https://github.com/google/jetstream-pytorch.git
    

    (Facultatif) Créez un environnement Python virtuel à l'aide de venv ou conda, puis l'activer.

  3. Exécuter le script d'installation

       cd jetstream-pytorch
       source install_everything.sh
    

Télécharger et convertir des pondérations

  1. Téléchargez les poids officiels de Llama sur GitHub.

  2. Convertissez les pondérations.

    • Définissez ${IN_CKPOINT} sur l'emplacement contenant les pondérations du lama
    • Définir ${OUT_CKPOINT} sur un point de contrôle en écriture pour un établissement
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

Exécuter le moteur PyTorch JetStream en local

Pour exécuter le moteur JetStream PyTorch localement, définissez le chemin d'accès de la fonction de tokenisation:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Faire fonctionner le moteur JetStream PyTorch avec Llama 7B

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Faire fonctionner le moteur JetStream PyTorch avec Llama 13b

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Exécuter le serveur JetStream

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

REMARQUE: Le paramètre --platform=tpu= doit spécifier le nombre d'appareils TPU. (soit 4 pour v4-8 et 8 pour v5lite-8). Exemple : --platform=tpu=8.

Après l'exécution de run_server.py, le moteur JetStream PyTorch est prêt à recevoir des appels gRPC.

Exécuter des benchmarks

Accédez au dossier deps/JetStream qui a été téléchargé lors de l'exécution install_everything.sh

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Pour en savoir plus, consultez deps/JetStream/benchmarks/README.md.

Erreurs courantes

Si une erreur Unexpected keyword argument 'device' s'affiche, essayez ce qui suit:

  • Désinstaller les dépendances jax et jaxlib
  • Réinstaller à l'aide de source install_everything.sh

Si une erreur Out of memory s'affiche, essayez ce qui suit:

  • Utiliser une taille de lot plus petite
  • Utiliser la quantification

Effectuer un nettoyage

Pour éviter que les ressources utilisées lors de ce tutoriel soient facturées sur votre compte Google Cloud, supprimez le projet contenant les ressources, ou conservez le projet et supprimez les ressources individuelles.

  1. Nettoyer les dépôts GitHub

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Nettoyer l'environnement virtuel Python

    rm -rf .env
    
  3. Supprimer vos ressources TPU

    Pour en savoir plus, consultez Supprimer vos ressources TPU.