v5e Cloud TPU VM での JetStream PyTorch 推論


JetStream は、XLA デバイス(TPU)での大規模言語モデル(LLM)推論向けのスループットとメモリ最適化エンジンです。

始める前に

Cloud TPU 環境の設定の手順に沿って、Google Cloud プロジェクトを作成し、TPU API を有効にして、TPU CLI をインストールし、TPU 割り当てをリクエストします。

CreateNode API を使用して Cloud TPU を作成するの手順に沿って、--accelerator-typev5litepod-8 に設定して TPU VM を作成します。

JetStream リポジトリのクローンを作成し、依存関係をインストールする

  1. SSH を使用して TPU VM に接続する

    • ${TPU_NAME} を TPU の名前に設定します。
    • ${PROJECT} を Google Cloud プロジェクトに設定する
    • ${ZONE} を、TPU を作成する Google Cloud ゾーンに設定します。
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. JetStream リポジトリのクローンを作成する

       git clone https://github.com/google/jetstream-pytorch.git
    

    (省略可)venv または conda を使用して仮想 Python 環境を作成し、有効にします。

  3. インストール スクリプトを実行します。

       cd jetstream-pytorch
       source install_everything.sh
    

重みをダウンロードして変換する

  1. GitHub から公式の Llama 重みをダウンロードします。

  2. 重みを変換します。

    • ${IN_CKPOINT} を Llama の重みを含む場所に設定します。
    • ${OUT_CKPOINT} をロケーション書き込みチェックポイントに設定する
    export input_ckpt_dir=${IN_CKPOINT}
    export output_ckpt_dir=${OUT_CKPOINT}
    export quantize=True
    python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
    

JetStream PyTorch エンジンをローカルで実行する

JetStream PyTorch エンジンをローカルで実行するには、トークナイザ パスを設定します。

export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama

Llama 7B で JetStream PyTorch エンジンを実行する

python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

Llama 13b で JetStream PyTorch エンジンを実行する

python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path

JetStream サーバーを実行する

python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir   --tokenizer_path=$tokenizer_path --platform=tpu=8

注: --platform=tpu= パラメータでは TPU デバイスの数を指定する必要があります(v4-8 の場合は 4、v5lite-8 の場合は 8)。例: --platform=tpu=8

run_server.py を実行すると、JetStream PyTorch エンジンが gRPC 呼び出しを受信できるようになります。

ベンチマークを実行する

install_everything.sh の実行時にダウンロードした deps/JetStream フォルダに移動します。

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

詳細については、deps/JetStream/benchmarks/README.md をご覧ください。

一般的なエラー

Unexpected keyword argument 'device' エラーが発生した場合は、次の手順をお試しください。

  • jaxjaxlib の依存関係をアンインストールする
  • source install_everything.sh を使用して再インストールする

Out of memory エラーが発生した場合は、次の手順をお試しください。

  • バッチサイズを小さくする
  • 量子化を使用する

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

  1. GitHub リポジトリをクリーンアップする

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Python 仮想環境をクリーンアップする

    rm -rf .env
    
  3. TPU リソースの削除

    詳細については、TPU リソースを削除するをご覧ください。