Cloud TPU v5e での SAX
SAX クラスタ(SAX セル)
SAX 管理サーバーと SAX モデルサーバーは、SAX クラスタを実行する 2 つの重要なコンポーネントです。
SAX 管理サーバー
SAX 管理サーバーは、SAX クラスタ内のすべての SAX モデルサーバーをモニタリングし、調整します。SAX クラスタでは、複数の SAX 管理サーバーを起動できます。この場合、リーダーの選択で SAX 管理サーバーのうち 1 つだけがアクティブになり、その他はスタンバイ サーバーになります。アクティブな管理サーバーに障害が発生すると、スタンバイ管理サーバーがアクティブになります。アクティブな SAX 管理サーバーは、使用可能な SAX モデルサーバーにモデルレプリカと推論リクエストを割り当てます。
SAX 管理ストレージ バケット
各 SAX クラスタには、SAX 管理サーバーと SAX モデルサーバーの構成とロケーションを SAX クラスタに保存する Cloud Storage バケットが必要です。
SAX モデルサーバー
SAX モデルサーバーは、モデル チェックポイントを読み込み、GSPMD を使用して推論を実行します。SAX モデルサーバーは、1 つの TPU VM ワーカーで動作します。単一ホストの TPU モデルを提供するには、単一ホストの TPU VM 上に単一の SAX モデルサーバーが必要です。マルチホスト TPU モデルを提供するには、マルチホスト TPU スライス上に SAX モデルサーバー グループが必要です。マルチホスト モデルの提供は現在利用できませんが、このドキュメントではプレビュー用に 175B テストモデルの例を提供しています。
SAX モデルの提供
次のセクションでは、SAX を使用して言語モデルを提供するワークフローについて説明します。これは、単一ホストのモデル提供の例として GPT-J 6B モデルを使用しています。
始める前に、TPU VM に Cloud TPU SAX Docker イメージをインストールします。
sudo usermod -a -G docker ${USER} newgrp docker gcloud auth configure-docker us-docker.pkg.dev SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server" SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server" SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util" SAX_VERSION=v1.0.0 export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION} export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION} export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}" docker pull ${SAX_ADMIN_SERVER_IMAGE_URL} docker pull ${SAX_MODEL_SERVER_IMAGE_URL} docker pull ${SAX_UTIL_IMAGE_URL}
後で使用するその他の変数を設定します。
export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server" export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server" export SAX_CELL="/sax/test"
GPT-J 6B 単一ホストモデル提供の例
単一ホストモデルの提供は、単一ホスト TPU スライス(v5litepod-1、v5litepod-4、v5litepod-8)に適用できます。
SAX クラスタを作成する
SAX クラスタ用の Cloud Storage ストレージ バケットを作成します。
SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket} gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \ --project=${PROJECT_ID}
チェックポイントを保存するために、別の Cloud Storage ストレージ バケットが必要になる場合があります。
SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
ターミナルの TPU VM に SSH 接続し、SAX 管理サーバーを起動します。
docker run \ --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \ -it \ -d \ --rm \ --network host \ --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \ ${SAX_ADMIN_SERVER_IMAGE_URL}
Docker ログは次の方法で確認できます。
docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
ログの出力は次のようになります。
I0829 01:22:31.184198 7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.347883 7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.360837 24 admin_server.go:44] Starting the server I0829 01:22:31.361420 24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8. I0829 01:22:31.361455 24 ipaddr.go:39] Skipping non-global IP address ::1/128. I0829 01:22:31.361462 24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64. I0829 01:22:31.361469 24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64. I0829 01:22:31.361474 24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64. I0829 01:22:31.361482 24 ipaddr.go:56] IPNet address 10.142.15.200 I0829 01:22:31.361488 24 ipaddr.go:56] IPNet address 172.17.0.1 I0829 01:22:31.456952 24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.609323 24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000" I0829 01:22:31.656021 24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root" I0829 01:22:31.773245 24 mgr.go:781] Loaded manager state I0829 01:22:31.773260 24 mgr.go:784] Refreshing manager state every 10s I0829 01:22:31.773285 24 admin.go:350] Starting the server on port 10000 I0829 01:22:31.773292 24 cloud.go:506] Starting the HTTP server on port 8080
単一ホストの SAX モデルサーバーを SAX クラスタに起動します。
この時点で、SAX クラスタには SAX 管理サーバーのみが含まれています。第 2 のターミナルで SSH を介して TPU VM に接続し、SAX クラスタで SAX モデルサーバーを起動できます。
docker run \ --privileged \ -it \ -d \ --rm \ --network host \ --name ${SAX_MODEL_SERVER_DOCKER_NAME} \ --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ ${SAX_MODEL_SERVER_IMAGE_URL} \ --sax_cell=${SAX_CELL} \ --port=10001 \ --platform_chip=tpuv4 \ --platform_topology=1x1
モデルのチェックポイントを変換します。
PyTorch と Transformer をインストールして、EleutherAI から GPT-J チェックポイントをダウンロードする必要があります。
pip3 install accelerate pip3 install torch pip3 install transformers
チェックポイントを SAX チェックポイントに変換するには、
paxml
をインストールする必要があります。pip3 install paxml==1.1.0
次のスクリプトは、GPT-J チェックポイントを SAX チェックポイントに変換します。
python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b
変換後:
ls checkpoint_00000000/
commit_success ファイルを作成し、サブディレクトリに配置する必要があります。
gsutil -m cp -r checkpoint_00000000 ${CHECKPOINT_PATH} touch commit_success.txt gsutil cp commit_success.txt ${CHECKPOINT_PATH}/ gsutil cp commit_success.txt ${CHECKPOINT_PATH}/metadata/ gsutil cp commit_success.txt ${CHECKPOINT_PATH}/state/
SAX クラスタにモデルを公開する
これで、前のステップで変換したチェックポイントを使用して GPT-J を公開できるようになりました。
MODEL_NAME=gptjtokenizedbf16bs32 MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32 REPLICA=1
GPT-J(およびその後の手順)を公開するには、SSH を使用して第 3 のターミナルで TPU VM に接続します。
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ publish \ ${SAX_CELL}/${MODEL_NAME} \ ${MODEL_CONFIG_PATH} \ ${CHECKPOINT_PATH} \ ${REPLICA}
モデルが正常に読み込まれたことを示す次のような出力が表示されるまで、モデルサーバーの Docker ログの多くのアクティビティが表示されます。
I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
推論結果を生成する
GPT-J では、入出力はカンマ区切りのトークン ID 文字列としてフォーマットされる必要があります。テキスト入力はトークン化する必要があります。
TEXT = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction\:\nSummarize the following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly international in scope on Tuesday. We're visiting Italy, Russia, the United Arab Emirates, and the Himalayan Mountains. Find out who's attempting to circumnavigate the globe in a plane powered partially by the sun, and explore the mysterious appearance of craters in northern Asia. You'll also get a view of Mount Everest that was previously reserved for climbers. On this page you will find today's show Transcript and a place for you to request to be on the CNN Student News Roll Call. TRANSCRIPT . Click here to access the transcript of today's CNN Student News program. Please note that there may be a delay between the time when the video is available and when the transcript is published. CNN Student News is created by a team of journalists who consider the Common Core State Standards, national standards in different subject areas, and state standards when producing the show. ROLL CALL . For a chance to be mentioned on the next CNN Student News, comment on the bottom of this page with your school name, mascot, city and state. We will be selecting schools from the comments of the previous show. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call! Thank you for using CNN Student News!\n\n### Response\:
トークン ID 文字列は、EleutherAI/gpt-j-6b トークナイザを使用して取得できます。
from transformers import GPT2Tokenizer tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b") :
入力テキストをトークン化します。
encoded_example = tokenizer(TEXT) input_ids = encoded_example.input_ids INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
次のようなトークン ID 文字列が想定されます。
>>> INPUT_STR '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
記事の概要を生成するには:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ lm.generate \ ${SAX_CELL}/${MODEL_NAME} \ ${INPUT_STR}
次のような出力が想定されます。
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | GENERATE | SCORE | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+ | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256 | -0.91842502 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.1726116 | | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -1.2472695 | +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
出力トークン ID の文字列をトークン化解除するには:
output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')] OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
トークン化解除されたテキストは次のようになります。
>>> OUTPUT_TEXT 'This page includes the show Transcript.\nUse the Transcript to help students with reading comprehension and vocabulary.\nAt the bottom of the page, comment for a chance to be mentioned on CNN Student News. You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
Docker コンテナと Cloud Storage ストレージ バケットをクリーンアップします。
175B マルチホスト モデル提供のプレビュー
一部の大規模言語モデルでは、マルチホスト TPU スライス(v5litepod-16 以降)が必要になります。このような場合、すべてのマルチホスト TPU ホストで SAX モデルサーバーのコピーが必要になり、すべてのモデルサーバーがマルチホスト TPU スライスで大規模なモデルを提供する SAX モデルサーバー グループとして機能します。
新しい SAX クラスタを作成する
GPT-J チュートリアルで SAX クラスタを作成するのと同じ手順に沿って、新しい SAX クラスタと SAX 管理サーバーを作成できます。
または、既存の SAX クラスタがある場合は、SAX クラスタにマルチホスト モデルサーバーを起動できます。
マルチホストの SAX モデルサーバーを SAX クラスタに起動する
単一ホスト TPU スライスの場合と同じコマンドを使用して、マルチホスト TPU スライスを作成します。その際必要な追加の手順は、適切なマルチホスト アクセラレータ タイプを指定することのみです。
ACCELERATOR_TYPE=v5litepod-32 ZONE=us-east1-c gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --service-account ${SERVICE_ACCOUNT} \ --reserved
SAX モデルサーバー イメージを、すべての TPU ホスト / ワーカーに pull して起動します。
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=" gcloud auth configure-docker \ us-docker.pkg.dev # Pull sax model server image docker pull ${SAX_MODEL_SERVER_IMAGE_URL} # Run model server docker run \ --privileged \ -it \ -d \ --rm \ --network host \ --name ${SAX_MODEL_SERVER_DOCKER_NAME} \ --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ ${SAX_MODEL_SERVER_IMAGE_URL} \ --sax_cell=${SAX_CELL} \ --port=10001 \ --platform_chip=tpuv4 \ --platform_topology=1x1"
SAX クラスタにモデルを公開する
この例では、LmCloudSpmd175B32Test モデルを使用します。
MODEL_NAME=lmcloudspmd175b32test MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test CHECKPOINT_PATH=None REPLICA=1
テストモデルを公開するには:
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ publish \ ${SAX_CELL}/${MODEL_NAME} \ ${MODEL_CONFIG_PATH} \ ${CHECKPOINT_PATH} \ ${REPLICA}
推論結果を生成する
docker run \ ${SAX_UTIL_IMAGE_URL} \ --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \ lm.generate \ ${SAX_CELL}/${MODEL_NAME} \ "Q: Who is Harry Porter's mother? A\: "
この例ではランダムな重みを持つテストモデルを使用しているため、出力は意味がない可能性があることにご注意ください。
クリーンアップ
Docker コンテナを停止します。
docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME} docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
次のように、
gsutil
を使用して Cloud Storage 管理ストレージ バケットとデータ ストレージ バケットを削除します。gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET} gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}