JetStream-PyTorch-Inferenz auf v5e-Cloud TPU-VM


JetStream ist eine durchsatz- und speicheroptimierte Engine für die LLM-Inferenz (Large Language Model) auf XLA-Geräten (TPUs).

Hinweise

Führen Sie die Schritte unter Cloud TPU-Umgebung einrichten aus, um ein Google Cloud-Projekt zu erstellen, die TPU API zu aktivieren, die TPU CLI zu installieren und ein TPU-Kontingent anzufordern.

Folgen Sie der Anleitung unter Cloud TPU mit der CreateNode API erstellen, um eine TPU-VM mit der Einstellung --accelerator-type zu v5litepod-8 zu erstellen.

JetStream-Repository klonen und Abhängigkeiten installieren

  1. Über SSH eine Verbindung zur TPU-VM herstellen

    • Legen Sie ${TPU_NAME} auf den Namen Ihrer TPU fest.
    • Legen Sie ${PROJECT} auf Ihr Google Cloud-Projekt fest.
    • Legen Sie ${ZONE} auf die Google Cloud-Zone fest, in der die TPUs erstellt werden sollen.
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. JetStream-Repository klonen

       git clone https://github.com/google/jetstream-pytorch.git
    

    Optional: Erstellen Sie mit venv oder conda eine virtuelle Python-Umgebung und aktivieren Sie sie.

  3. Installationsskript ausführen

       cd jetstream-pytorch
       source install_everything.sh
    

Gewichte herunterladen und konvertieren

  1. Laden Sie die offiziellen Llama-Gewichte von GitHub herunter.

  2. Wandeln Sie die Gewichte um.

    • Legen Sie ${IN_CKPOINT} auf den Speicherort fest, der die Lama-Gewichte enthält.
    • Legen Sie ${OUT_CKPOINT} auf einen Speicherort für Schreib-Prüfpunkte fest.
    export input_ckpt_dir=${IN_CKPOINT} 
    export output_ckpt_dir=${OUT_CKPOINT} 
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

JetStream PyTorch-Engine lokal ausführen

Wenn Sie die JetStream PyTorch-Engine lokal ausführen möchten, legen Sie den Pfad zum Tokenisierer fest:

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

JetStream-PyTorch-Engine mit Llama 7B ausführen

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

JetStream-PyTorch-Engine mit Llama 13b ausführen

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

JetStream-Server ausführen

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

HINWEIS: Der Parameter --platform=tpu= muss die Anzahl der TPU-Geräte angeben (4 für v4-8 und 8 für v5lite-8). Beispiel: --platform=tpu=8.

Nach dem Ausführen von run_server.py ist die JetStream PyTorch-Engine bereit, gRPC-Aufrufe zu empfangen.

Benchmarks ausführen

Wechseln Sie zum Ordner deps/JetStream, der beim Ausführen von install_everything.sh heruntergeladen wurde.

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

Weitere Informationen finden Sie unter deps/JetStream/benchmarks/README.md.

Typische Fehler

Wenn du einen Unexpected keyword argument 'device'-Fehler erhältst, versuche Folgendes:

  • Abhängigkeiten von jax und jaxlib deinstallieren
  • Mit source install_everything.sh neu installieren

Wenn du einen Out of memory-Fehler erhältst, versuche Folgendes:

  • Kleinere Batchgröße verwenden
  • Quantisierung verwenden

Bereinigen

Damit Ihrem Google Cloud-Konto die in dieser Anleitung verwendeten Ressourcen nicht in Rechnung gestellt werden, löschen Sie entweder das Projekt, das die Ressourcen enthält, oder Sie behalten das Projekt und löschen die einzelnen Ressourcen.

  1. GitHub-Repositories bereinigen

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Virtuelle Python-Umgebung bereinigen

    rm -rf .env
    
  3. TPU-Ressourcen löschen

    Weitere Informationen finden Sie unter TPU-Ressourcen löschen.