始める前に
Cloud TPU 環境の設定の手順に沿って、Google Cloud プロジェクトを作成し、TPU API を有効にして、TPU CLI をインストールし、TPU 割り当てをリクエストします。
CreateNode API を使用して Cloud TPU を作成するの手順に沿って、--accelerator-type
を v5litepod-8
に設定して TPU VM を作成します。
JetStream リポジトリのクローンを作成し、依存関係をインストールする
SSH を使用して TPU VM に接続する
- ${TPU_NAME} を TPU の名前に設定します。
- ${PROJECT} を Google Cloud プロジェクトに設定する
- ${ZONE} を、TPU を作成する Google Cloud ゾーンに設定します。
gcloud compute config-ssh gcloud compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT} --zone ${ZONE}
JetStream リポジトリのクローンを作成する
git clone https://github.com/google/jetstream-pytorch.git
(省略可)
venv
またはconda
を使用して仮想 Python 環境を作成し、有効にします。インストール スクリプトを実行します。
cd jetstream-pytorch source install_everything.sh
重みをダウンロードして変換する
GitHub から公式の Llama 重みをダウンロードします。
重みを変換します。
- ${IN_CKPOINT} を Llama の重みを含む場所に設定します。
- ${OUT_CKPOINT} をロケーション書き込みチェックポイントに設定する
export input_ckpt_dir=${IN_CKPOINT} export output_ckpt_dir=${OUT_CKPOINT} export quantize=True python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
JetStream PyTorch エンジンをローカルで実行する
JetStream PyTorch エンジンをローカルで実行するには、トークナイザ パスを設定します。
export tokenizer_path=${TOKENIZER_PATH} # tokenizer model file path from meta-llama
Llama 7B で JetStream PyTorch エンジンを実行する
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
Llama 13b で JetStream PyTorch エンジンを実行する
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path
JetStream サーバーを実行する
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8
注: --platform=tpu=
パラメータでは TPU デバイスの数を指定する必要があります(v4-8
の場合は 4、v5lite-8
の場合は 8)。例: --platform=tpu=8
run_server.py
を実行すると、JetStream PyTorch エンジンが gRPC 呼び出しを受信できるようになります。
ベンチマークを実行する
install_everything.sh
の実行時にダウンロードした deps/JetStream
フォルダに移動します。
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
詳細については、deps/JetStream/benchmarks/README.md
をご覧ください。
一般的なエラー
Unexpected keyword argument 'device'
エラーが発生した場合は、次の手順をお試しください。
jax
とjaxlib
の依存関係をアンインストールするsource install_everything.sh
を使用して再インストールする
Out of memory
エラーが発生した場合は、次の手順をお試しください。
- バッチサイズを小さくする
- 量子化を使用する
クリーンアップ
このチュートリアルで使用したリソースについて、Google Cloud アカウントに課金されないようにするには、リソースを含むプロジェクトを削除するか、プロジェクトを維持して個々のリソースを削除します。
GitHub リポジトリをクリーンアップする
# Clean up the JetStream repository rm -rf JetStream # Clean up the xla repository rm -rf xla
Python 仮想環境をクリーンアップする
rm -rf .env
TPU リソースの削除
詳細については、TPU リソースを削除するをご覧ください。