v5e Cloud TPU VM での JetStream PyTorch 推論


JetStream は、XLA デバイス(TPU)での大規模言語モデル(LLM)推論用にスループットとメモリが最適化されたエンジンです。

始める前に

Cloud TPU 環境を設定するの手順に沿って、 Google Cloud プロジェクトを作成し、TPU API を有効にして、TPU CLI をインストールするとともに、TPU 割り当てをリクエストします。

CreateNode API を使用して Cloud TPU を作成するの手順に沿って、--accelerator-typev5litepod-8 に設定する TPU VM を作成します。

JetStream リポジトリのクローンを作成して依存関係をインストールする

  1. SSH を使用して TPU VM に接続する

    • ${TPU_NAME} を TPU の名前に設定します。
    • ${PROJECT} を Google Cloud プロジェクトに設定します。
    • ${ZONE} を TPU を作成する Google Cloud ゾーンに設定します。
      gcloud compute config-ssh
      gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
    
  2. jetstream-pytorch コードを取得する bash git clone https://github.com/google/jetstream-pytorch.git git checkout jetstream-v0.2.4

(省略可)venv または conda を使用して仮想環境を作成し、有効にします。

sudo apt install python3.10-venv
python -m venv venv
source venv/bin/activate
  1. インストール スクリプトを実行します。
cd jetstream-pytorch
source install_everything.sh

jetstream pytorch を実行する

サポートされているモデルを一覧表示する

jpt list

サポートされているモデルとバリエーションのリストが出力されます。

meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1

1 つのモデルで jetstream-pytorch サーバーを実行するには: bash jpt serve --model_id meta-llama/Llama-2-7b-chat-hf

このモデルを初めて実行すると、jpt serve コマンドは HuggingFace から重みをダウンロードしようとします。この場合、HuggingFace で認証する必要があります。

認証するには、huggingface-cli login を実行してアクセス トークンを設定するか、--hf_token フラグを使用して HuggingFace アクセス トークンを jpt serve コマンドに渡します。

jpt serve --model_id meta-llama/Llama-2-7b-chat-hf --hf_token=...

HuggingFace アクセス トークンの詳細については、アクセス トークンをご覧ください。

HuggingFace Hub を使用してログインするには、次のコマンドを実行し、プロンプトに沿って操作します。

pip install -U "huggingface_hub[cli]"
huggingface-cli login

重みがダウンロードされたら、--hf_token フラグを指定する必要はありません。

int8 量子化でこのモデルを実行するには、--quantize_weights=1 を追加します。量子化は、重みが読み込まれる際にフリート上で行われます。

HuggingFace からダウンロードされた重みは、デフォルトでは jpt を実行するディレクトリの checkpoints フォルダに保存されます。--working_dir フラグを使用してディレクトリを変更することもできます。

独自のチェックポイントを使用する場合は、checkpoints/<org>/<model>/hf_original ディレクトリ(または --working_dir 内の対応するサブディレクトリ)に配置します。たとえば、Llama2-7b チェックポイントは checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors にあります。これらのファイルは、HuggingFace 形式の変更済み重みで置き換えることができます。

ベンチマークを実行する

install_everything.sh の実行時にダウンロードされた deps/JetStream フォルダに移動します。

cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000  --dataset-path  $dataset_path --dataset sharegpt --save-request-outputs

詳細については、deps/JetStream/benchmarks/README.md をご覧ください。

代表的なエラー

Unexpected keyword argument 'device' エラーが発生した場合は、次の手順をお試しください。

  • jaxjaxlib の依存関係をアンインストールする
  • source install_everything.sh を使用して再インストールする

Out of memory エラーが発生した場合は、次の手順をお試しください。

  • 使用するバッチサイズを縮小する
  • 量子化を使用する

クリーンアップ

このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。

  1. GitHub リポジトリをクリーンアップする

      # Clean up the JetStream repository
      rm -rf JetStream
    
      # Clean up the xla repository
      rm -rf xla
    
  2. Python 仮想環境をクリーンアップする

    rm -rf .env
    
  3. TPU リソースを削除する

    詳細については、TPU リソースを削除するをご覧ください。