Cloud TPU v5e 上的 SAX

SAX 集群(SAX 单元)

SAX 管理服务器和 SAX 模型服务器是运行 SAX 集群的两个基本组件。

SAX 管理服务器

SAX 管理服务器监控和协调 SAX 集群中的所有 SAX 模型服务器。在 SAX 集群中,您可以启动多个 SAX 管理服务器,其中只有一个 SAX 管理服务器通过领导者选举处于活跃状态,其他是备用服务器。 当活跃管理服务器发生故障时,备用管理服务器将变为活跃状态。活跃的 SAX 管理服务器将模型副本和推断请求分配给可用的 SAX 模型服务器。

SAX 管理员存储桶

每个 SAX 集群都需要一个 Cloud Storage 存储桶,以将 SAX 管理服务器和 SAX 模型服务器的配置和位置存储在 SAX 集群中。

SAX 模型服务器

SAX 模型服务器会加载模型检查点,并使用 GSPMD 进行推断。SAX 模型服务器在单个 TPU 虚拟机工作器上运行。单主机 TPU 模型服务需要在单主机 TPU 虚拟机上配置单个 SAX 模型服务器。多主机 TPU 模型服务需要在多主机 TPU 切片上包含一组 SAX 模型服务器。多主机模型服务目前不可用,但本文档提供了包含 175B 测试模型的示例以供预览。

SAX 模型服务

以下部分介绍了使用 SAX 提供语言模型的工作流。它使用 GPT-J 6B 模型作为单主机模型投放示例。

开始之前,请在 TPU 虚拟机上安装 Cloud TPU SAX Docker 映像:

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

设置稍后要使用的一些其他变量:

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

GPT-J 6B 单主机模型投放示例

单主机模型服务适用于单主机 TPU 切片,即 v5litepod-1、v5litepod-4 和 v5litepod-8。

  1. 创建 SAX 集群

    1. 为 SAX 集群创建 Cloud Storage 存储桶:

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}
      

      您可能需要另一个 Cloud Storage 存储桶来存储检查点。

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
      
    2. 在终端中通过 SSH 连接到 TPU 虚拟机,以启动 SAX 管理服务器:

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}
      

      您可以通过以下方式查看 Docker 日志:

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
      

      日志的输出将类似于以下内容:

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
      
  2. 将单主机 SAX 模型服务器启动到 SAX 集群中:

    此时,SAX 集群仅包含 SAX 管理服务器。 您可以在第二个终端中通过 SSH 连接到 TPU 虚拟机,以便在 SAX 集群中启动 SAX 模型服务器:

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
    
  3. 转换模型检查点:

    您需要安装 PyTorch 和 Transformer,才能从 EleutherAI 下载 GPT-J 检查点:

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers
    

    如需将检查点转换为 SAX 检查点,您需要安装 paxml

    pip3 install paxml==1.1.0
    

    以下脚本会将 GPT-J 检查点转换为 SAX 检查点:

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b
    

    转换完成后:

    ls checkpoint_00000000/
    

    您需要创建一个 commit_success 文件,并将其放置在子目录中:

    gsutil -m cp -r checkpoint_00000000 ${CHECKPOINT_PATH}
    
    touch commit_success.txt
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. 将模型发布到 SAX 集群

    现在,您可以发布包含已在上一步中转换的检查点的 GPT-J 了。

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1
    

    如需发布 GPT-J(以及之后的步骤),请在第三个终端中使用 SSH 连接到您的 TPU 虚拟机:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    

    您将看到模型服务器 Docker 日志中的大量活动,直到您看到如下内容指示模型已成功加载:

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. 生成推断结果

    对于 GPT-J,输入和输出的格式必须为以英文逗号分隔的令牌 ID 字符串。您将需要对文本输入进行词元化处理。

    TEXT = "Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:
    

    您可以通过 EleutherAI/gpt-j-6b 标记生成器获取令牌 ID 字符串:

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")                  :
    

    对输入文本进行词元化处理:

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
    

    您应该得到类似于以下内容的令牌 ID 字符串:

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
    

    如需生成报道的摘要,请执行以下操作:

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    您应该会看到类似如下所示的情况:

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    

    要对输出令牌 ID 字符串进行去令牌化处理,请执行以下操作:

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    去令牌化后的文本应如下所示:

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
    
  6. 清理 Docker 容器和 Cloud Storage 存储分区。

1750 亿多主机模型服务预览

一些大型语言模型需要多主机 TPU 切片,即 v5litepod-16 及更高版本。在这些情况下,所有多主机 TPU 主机都需要有一个 SAX 模型服务器的副本,并且所有模型服务器充当 SAX 模型服务器组,以在多主机 TPU 切片上提供大型模型。

  1. 创建新的 SAX 集群

    您可以按照 GPT-J 演示中“创建 SAX 集群”的同一步骤进行操作,以创建新的 SAX 集群和 SAX 管理服务器。

    或者,如果您已有 SAX 集群,则可以在 SAX 集群中启动多主机模型服务器。

  2. 将多主机 SAX 模型服务器启动到 SAX 集群中

    使用与用于单主机 TPU 切片相同的命令创建多主机 TPU 切片,只需指定相应的多主机加速器类型即可:

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

    如需将 SAX 模型服务器映像拉取到所有 TPU 主机/工作器并启动它们,请执行以下操作:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. 将模型发布到 SAX 集群

    此示例使用 LmCloudSpmd175B32Test 模型:

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1
    

    如需发布测试模型,请执行以下操作:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    
  4. 生成推断结果

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A\: "
    

    请注意,由于此示例使用的是具有随机权重的测试模型,因此输出可能没有意义。

  5. 清理

    停止 Docker 容器:

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
    

    使用 gsutil 删除 Cloud Storage 管理员存储桶和任何数据存储桶,如下所示。

    gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET}
    gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}