Cloud TPU v5e での SAX

SAX クラスタ(SAX セル)

SAX 管理サーバーと SAX モデルサーバーは、SAX クラスタを実行する 2 つの重要なコンポーネントです。

SAX 管理サーバー

SAX 管理サーバーは、SAX クラスタ内のすべての SAX モデルサーバーをモニタリングし、調整します。SAX クラスタでは、複数の SAX 管理サーバーを起動できます。この場合、リーダーの選択で SAX 管理サーバーのうち 1 つだけがアクティブになり、その他はスタンバイ サーバーになります。アクティブな管理サーバーに障害が発生すると、スタンバイ管理サーバーがアクティブになります。アクティブな SAX 管理サーバーは、使用可能な SAX モデルサーバーにモデルレプリカと推論リクエストを割り当てます。

SAX 管理ストレージ バケット

各 SAX クラスタには、SAX 管理サーバーと SAX モデルサーバーの構成とロケーションを SAX クラスタに保存する Cloud Storage バケットが必要です。

SAX モデルサーバー

SAX モデルサーバーは、モデル チェックポイントを読み込み、GSPMD を使用して推論を実行します。SAX モデルサーバーは、1 つの TPU VM ワーカーで動作します。単一ホストの TPU モデルを提供するには、単一ホストの TPU VM 上に単一の SAX モデルサーバーが必要です。マルチホスト TPU モデルを提供するには、マルチホスト TPU スライス上に SAX モデルサーバー グループが必要です。マルチホスト モデルの提供は現在利用できませんが、このドキュメントではプレビュー用に 175B テストモデルの例を提供しています。

SAX モデルの提供

次のセクションでは、SAX を使用して言語モデルを提供するワークフローについて説明します。これは、単一ホストのモデル提供の例として GPT-J 6B モデルを使用しています。

始める前に、TPU VM に Cloud TPU SAX Docker イメージをインストールします。

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

後で使用するその他の変数を設定します。

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

GPT-J 6B 単一ホストモデル提供の例

単一ホストモデルの提供は、単一ホスト TPU スライス(v5litepod-1、v5litepod-4、v5litepod-8)に適用できます。

  1. SAX クラスタを作成する

    1. SAX クラスタ用の Cloud Storage ストレージ バケットを作成します。

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}
      

      チェックポイントを保存するために、別の Cloud Storage ストレージ バケットが必要になる場合があります。

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
      
    2. ターミナルの TPU VM に SSH 接続し、SAX 管理サーバーを起動します。

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}
      

      Docker ログは次の方法で確認できます。

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
      

      ログの出力は次のようになります。

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
      
  2. 単一ホストの SAX モデルサーバーを SAX クラスタに起動します。

    この時点で、SAX クラスタには SAX 管理サーバーのみが含まれています。第 2 のターミナルで SSH を介して TPU VM に接続し、SAX クラスタで SAX モデルサーバーを起動できます。

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
    
  3. モデルのチェックポイントを変換します。

    PyTorch と Transformer をインストールして、EleutherAI から GPT-J チェックポイントをダウンロードする必要があります。

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers
    

    チェックポイントを SAX チェックポイントに変換するには、paxml をインストールする必要があります。

    pip3 install paxml==1.1.0
    

    次のスクリプトは、GPT-J チェックポイントを SAX チェックポイントに変換します。

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b
    

    変換後:

    ls checkpoint_00000000/
    

    commit_success ファイルを作成し、サブディレクトリに配置する必要があります。

    gsutil -m cp -r checkpoint_00000000 ${CHECKPOINT_PATH}
    
    touch commit_success.txt
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. SAX クラスタにモデルを公開する

    これで、前のステップで変換したチェックポイントを使用して GPT-J を公開できるようになりました。

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1
    

    GPT-J(およびその後の手順)を公開するには、SSH を使用して第 3 のターミナルで TPU VM に接続します。

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    

    モデルが正常に読み込まれたことを示す次のような出力が表示されるまで、モデルサーバーの Docker ログの多くのアクティビティが表示されます。

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. 推論結果を生成する

    GPT-J では、入出力はカンマ区切りのトークン ID 文字列としてフォーマットされる必要があります。テキスト入力はトークン化する必要があります。

    TEXT = "Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:
    

    トークン ID 文字列は、EleutherAI/gpt-j-6b トークナイザを使用して取得できます。

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")                  :
    

    入力テキストをトークン化します。

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
    

    次のようなトークン ID 文字列が想定されます。

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
    

    記事の概要を生成するには:

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    次のような出力が想定されます。

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    

    出力トークン ID の文字列をトークン化解除するには:

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    トークン化解除されたテキストは次のようになります。

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
    
  6. Docker コンテナと Cloud Storage ストレージ バケットをクリーンアップします。

175B マルチホスト モデル提供のプレビュー

一部の大規模言語モデルでは、マルチホスト TPU スライス(v5litepod-16 以降)が必要になります。このような場合、すべてのマルチホスト TPU ホストで SAX モデルサーバーのコピーが必要になり、すべてのモデルサーバーがマルチホスト TPU スライスで大規模なモデルを提供する SAX モデルサーバー グループとして機能します。

  1. 新しい SAX クラスタを作成する

    GPT-J チュートリアルで SAX クラスタを作成するのと同じ手順に沿って、新しい SAX クラスタと SAX 管理サーバーを作成できます。

    または、既存の SAX クラスタがある場合は、SAX クラスタにマルチホスト モデルサーバーを起動できます。

  2. マルチホストの SAX モデルサーバーを SAX クラスタに起動する

    単一ホスト TPU スライスの場合と同じコマンドを使用して、マルチホスト TPU スライスを作成します。その際必要な追加の手順は、適切なマルチホスト アクセラレータ タイプを指定することのみです。

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

    SAX モデルサーバー イメージを、すべての TPU ホスト / ワーカーに pull して起動します。

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. SAX クラスタにモデルを公開する

    この例では、LmCloudSpmd175B32Test モデルを使用します。

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1
    

    テストモデルを公開するには:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    
  4. 推論結果を生成する

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A\: "
    

    この例ではランダムな重みを持つテストモデルを使用しているため、出力は意味がない可能性があることにご注意ください。

  5. クリーンアップ

    Docker コンテナを停止します。

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
    

    次のように、gsutil を使用して Cloud Storage 管理ストレージ バケットとデータ ストレージ バケットを削除します。

    gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET}
    gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}