SAX sur Cloud TPU v5e

Cluster SAX (cellule SAX)

Le serveur d'administration SAX et le serveur de modèles SAX sont deux composants essentiels qui exécutent un cluster SAX.

Serveur d'administration SAX

Le serveur d'administration SAX surveille et coordonne tous les serveurs de modèles SAX d'un cluster SAX. Dans un cluster SAX, vous pouvez lancer plusieurs serveurs d'administration SAX, où un seul serveur d'administration SAX est actif lors de l'élection du responsable, les autres étant des serveurs de secours. En cas de défaillance du serveur d'administration actif, un serveur d'administration de secours devient actif. Le serveur d'administration SAX actif attribue des instances répliquées de modèle et des requêtes d'inférence aux serveurs de modèles SAX disponibles.

Bucket de stockage administrateur SAX

Chaque cluster SAX nécessite un bucket Cloud Storage pour stocker la configuration et l'emplacement des serveurs d'administration SAX et des serveurs de modèles SAX dans le cluster SAX.

Serveur de modèles SAX

Le serveur de modèles SAX charge un point de contrôle du modèle et exécute une inférence avec GSPMD. Un serveur de modèles SAX s'exécute sur un seul nœud de calcul de VM TPU. La diffusion de modèles TPU à hôte unique nécessite un seul serveur de modèles SAX sur une VM TPU à hôte unique. L'inférence de modèles TPU à hôtes multiples nécessite un groupe de serveurs de modèles SAX sur une tranche de TPU multi-hôte. La diffusion de modèles à hôtes multiples n'est pas disponible actuellement, mais ce document fournit un exemple avec un modèle de test de 175 milliards à titre de prévisualisation.

Inférence du modèle SAX

La section suivante décrit le workflow de diffusion des modèles de langage à l'aide de SAX. Le modèle GPT-J 6B est utilisé comme exemple de diffusion du modèle à hôte unique.

Avant de commencer, installez les images Docker Cloud TPU SAX sur votre VM TPU:

sudo usermod -a -G docker ${USER}
newgrp docker

gcloud auth configure-docker us-docker.pkg.dev

SAX_ADMIN_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-admin-server"
SAX_MODEL_SERVER_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-model-server"
SAX_UTIL_IMAGE_NAME="us-docker.pkg.dev/cloud-tpu-images/inference/sax-util"

SAX_VERSION=v1.0.0

export SAX_ADMIN_SERVER_IMAGE_URL=${SAX_ADMIN_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_MODEL_SERVER_IMAGE_URL=${SAX_MODEL_SERVER_IMAGE_NAME}:${SAX_VERSION}
export SAX_UTIL_IMAGE_URL="${SAX_UTIL_IMAGE_NAME}:${sax_version}"

docker pull ${SAX_ADMIN_SERVER_IMAGE_URL}
docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
docker pull ${SAX_UTIL_IMAGE_URL}

Définissez d'autres variables que vous utiliserez plus tard:

export SAX_ADMIN_SERVER_DOCKER_NAME="sax-admin-server"
export SAX_MODEL_SERVER_DOCKER_NAME="sax-model-server"
export SAX_CELL="/sax/test"

Exemple de diffusion avec un modèle GPT-J 6B à hôte unique

La diffusion des modèles à hôte unique s'applique à la tranche de TPU à hôte unique, c'est-à-dire v5litepod-1, v5litepod-4 et v5litepod-8.

  1. Créer un cluster SAX

    1. Créez un bucket de stockage Cloud Storage pour le cluster SAX:

      SAX_ADMIN_STORAGE_BUCKET=${your_admin_storage_bucket}
      
      gcloud storage buckets create gs://${SAX_ADMIN_STORAGE_BUCKET} \
      --project=${PROJECT_ID}
      

      Vous aurez peut-être besoin d'un autre bucket de stockage Cloud Storage pour stocker le point de contrôle.

      SAX_DATA_STORAGE_BUCKET=${your_data_storage_bucket}
      
    2. Connectez-vous en SSH à votre VM TPU dans un terminal pour lancer le serveur d'administration SAX:

      docker run \
      --name ${SAX_ADMIN_SERVER_DOCKER_NAME} \
      -it \
      -d \
      --rm \
      --network host \
      --env GSBUCKET=${SAX_ADMIN_STORAGE_BUCKET} \
      ${SAX_ADMIN_SERVER_IMAGE_URL}
      

      Vous pouvez consulter le journal Docker en procédant comme suit:

      docker logs -f ${SAX_ADMIN_SERVER_DOCKER_NAME}
      

      Le résultat du journal ressemble à ce qui suit:

      I0829 01:22:31.184198       7 config.go:111] Creating config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.347883       7 config.go:115] Created config fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.360837      24 admin_server.go:44] Starting the server
      I0829 01:22:31.361420      24 ipaddr.go:39] Skipping non-global IP address 127.0.0.1/8.
      I0829 01:22:31.361455      24 ipaddr.go:39] Skipping non-global IP address ::1/128.
      I0829 01:22:31.361462      24 ipaddr.go:39] Skipping non-global IP address fe80::4001:aff:fe8e:fc8/64.
      I0829 01:22:31.361469      24 ipaddr.go:39] Skipping non-global IP address fe80::42:bfff:fef9:1bd3/64.
      I0829 01:22:31.361474      24 ipaddr.go:39] Skipping non-global IP address fe80::20fb:c3ff:fe5b:baac/64.
      I0829 01:22:31.361482      24 ipaddr.go:56] IPNet address 10.142.15.200
      I0829 01:22:31.361488      24 ipaddr.go:56] IPNet address 172.17.0.1
      I0829 01:22:31.456952      24 admin.go:305] Loaded config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.609323      24 addr.go:105] SetAddr /gcs/test_sax_admin/sax-root/sax/test/location.proto "10.142.15.200:10000"
      I0829 01:22:31.656021      24 admin.go:325] Updated config: fs_root: "gs://test_sax_admin/sax-fs-root"
      I0829 01:22:31.773245      24 mgr.go:781] Loaded manager state
      I0829 01:22:31.773260      24 mgr.go:784] Refreshing manager state every 10s
      I0829 01:22:31.773285      24 admin.go:350] Starting the server on port 10000
      I0829 01:22:31.773292      24 cloud.go:506] Starting the HTTP server on port 8080
      
  2. Lancez un serveur de modèle SAX à hôte unique dans le cluster SAX:

    À ce stade, le cluster SAX ne contient que le serveur d'administration SAX. Vous pouvez vous connecter à votre VM TPU via SSH dans un deuxième terminal pour lancer un serveur de modèle SAX dans votre cluster SAX:

    docker run \
        --privileged  \
        -it \
        -d \
        --rm \
        --network host \
        --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
        --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        ${SAX_MODEL_SERVER_IMAGE_URL} \
           --sax_cell=${SAX_CELL} \
           --port=10001 \
           --platform_chip=tpuv4 \
           --platform_topology=1x1
    
  3. Point de contrôle du modèle:

    Vous devez installer PyTorch et Transformers pour télécharger le point de contrôle GPT-J depuis EleutherAI:

    pip3 install accelerate
    pip3 install torch
    pip3 install transformers
    

    Pour convertir le point de contrôle en point de contrôle SAX, vous devez installer paxml:

    pip3 install paxml==1.1.0
    

    Le script suivant convertit le point de contrôle GPT-J en point de contrôle SAX:

    python3 -m convert_gptj_ckpt --base EleutherAI/gpt-j-6b --pax pax_6b
    

    Une fois la conversion terminée:

    ls checkpoint_00000000/
    

    Vous devez créer un fichier commit_success et le placer dans les sous-répertoires suivants:

    gsutil -m cp -r checkpoint_00000000 ${CHECKPOINT_PATH}
    
    touch commit_success.txt
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/metadata/
    gsutil cp commit_success.txt ${CHECKPOINT_PATH}/state/
    
  4. Publier le modèle sur le cluster SAX

    Vous pouvez maintenant publier GPT-J avec le point de contrôle converti à l'étape précédente.

    MODEL_NAME=gptjtokenizedbf16bs32
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.gptj.GPTJ4TokenizedBF16BS32
    REPLICA=1
    

    Pour publier le tag GPT-J (et les étapes suivantes), utilisez SSH pour vous connecter à votre VM TPU dans un troisième terminal:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    

    Vous constaterez une grande activité dans le journal Docker du serveur de modèles jusqu'à ce que vous voyiez une sortie semblable à celle-ci pour indiquer que le modèle a bien été chargé:

    I0829 01:33:49.287459 139865140229696 servable_model.py:697] loading completed.
    
  5. Générer des résultats d'inférence

    Pour GPT-J, l'entrée et la sortie doivent être formatées en une chaîne d'ID de jeton séparée par des virgules. Vous devez tokeniser l'entrée de texte.

    TEXT = "Below is an instruction that describes a task, paired with
    an input that provides further context. Write a response that
    appropriately completes the request.\n\n### Instruction\:\nSummarize the
    following news article\:\n\n### Input\:\nMarch 10, 2015 . We're truly
    international in scope on Tuesday. We're visiting Italy, Russia, the
    United Arab Emirates, and the Himalayan Mountains. Find out who's
    attempting to circumnavigate the globe in a plane powered partially by the
    sun, and explore the mysterious appearance of craters in northern Asia.
    You'll also get a view of Mount Everest that was previously reserved for
    climbers. On this page you will find today's show Transcript and a place
    for you to request to be on the CNN Student News Roll Call. TRANSCRIPT .
    Click here to access the transcript of today's CNN Student News program.
    Please note that there may be a delay between the time when the video is
    available and when the transcript is published. CNN Student News is
    created by a team of journalists who consider the Common Core State
    Standards, national standards in different subject areas, and state
    standards when producing the show. ROLL CALL . For a chance to be
    mentioned on the next CNN Student News, comment on the bottom of this page
    with your school name, mascot, city and state. We will be selecting
    schools from the comments of the previous show. You must be a teacher or a
    student age 13 or older to request a mention on the CNN Student News Roll
    Call! Thank you for using CNN Student News!\n\n### Response\:
    

    Vous pouvez obtenir la chaîne d'ID de jeton via le jeton EleutherAI/gpt-j-6b:

    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-j-6b")                  :
    

    Tokeniser le texte d'entrée:

    encoded_example = tokenizer(TEXT)
    input_ids = encoded_example.input_ids
    INPUT_STR = ",".join([str(input_id) for input_id in input_ids])
    

    Vous pouvez vous attendre à une chaîne d'ID de jeton semblable à celle-ci:

    >>> INPUT_STR
    '21106,318,281,12064,326,8477,257,4876,11,20312,351,281,5128,326,3769,2252,4732,13,19430,257,2882,326,20431,32543,262,2581,13,198,198,21017,46486,25,198,13065,3876,1096,262,1708,1705,2708,25,198,198,21017,23412,25,198,16192,838,11,1853,764,775,821,4988,3230,287,8354,319,3431,13,775,821,10013,8031,11,3284,11,262,1578,4498,24880,11,290,262,42438,22931,21124,13,9938,503,508,338,9361,284,2498,4182,615,10055,262,13342,287,257,6614,13232,12387,416,262,4252,11,290,7301,262,11428,5585,286,1067,8605,287,7840,7229,13,921,1183,635,651,257,1570,286,5628,41336,326,373,4271,10395,329,39311,13,1550,428,2443,345,481,1064,1909,338,905,42978,290,257,1295,329,345,284,2581,284,307,319,262,8100,13613,3000,8299,4889,13,48213,6173,46023,764,6914,994,284,1895,262,14687,286,1909,338,8100,13613,3000,1430,13,4222,3465,326,612,743,307,257,5711,1022,262,640,618,262,2008,318,1695,290,618,262,14687,318,3199,13,8100,13613,3000,318,2727,416,257,1074,286,9046,508,2074,262,8070,7231,1812,20130,11,2260,5423,287,1180,2426,3006,11,290,1181,5423,618,9194,262,905,13,15107,3069,42815,764,1114,257,2863,284,307,4750,319,262,1306,8100,13613,3000,11,2912,319,262,4220,286,428,2443,351,534,1524,1438,11,37358,11,1748,290,1181,13,775,481,307,17246,4266,422,262,3651,286,262,2180,905,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,6952,345,329,1262,8100,13613,3000,0,198,198,21017,18261,25'
    

    Pour générer un résumé de votre article:

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          ${INPUT_STR}
    

    Le résultat devrait ressembler à ceci:

    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    |                                                                                                                                                    GENERATE                                                                                                                                                    |    SCORE     |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256 | -0.023136413 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,0,50256  |  -0.91842502 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,921,1276,307,257,4701,393,257,3710,2479,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256     |   -1.1726116 |
    | 1212,2443,3407,262,905,42978,764,198,11041,262,42978,284,1037,2444,351,3555,35915,290,25818,764,198,2953,262,4220,286,262,2443,11,2912,329,257,2863,284,307,4750,319,8100,13613,3000,13,220,921,1276,307,1511,393,4697,284,2581,257,3068,319,262,8100,13613,3000,8299,4889,13,50256                            |   -1.2472695 |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------+
    

    Pour détokeniser la chaîne d'ID des jetons de sortie:

    output_token_ids = [int(token_id) for token_id in OUTPUT_STR.split(',')]
    OUTPUT_TEXT = tokenizer.decode(output_token_ids, skip_special_tokens=True)
    

    Vous pouvez vous attendre à ce que le texte détokenisé soit le suivant:

    >>> OUTPUT_TEXT
    'This page includes the show Transcript.\nUse the Transcript to help
    students with reading comprehension and vocabulary.\nAt the bottom of
    the page, comment for a chance to be mentioned on CNN Student News.
    You must be a teacher or a student age 13 or older to request a mention on the CNN Student News Roll Call.'
    
  6. Nettoyez vos conteneurs Docker et vos buckets de stockage Cloud Storage.

Aperçu de la diffusion du modèle à hôtes multiples de 175 milliards

Certains des grands modèles de langage nécessitent une tranche de TPU multi-hôte, c'est-à-dire v5litepod-16 et versions ultérieures. Dans ce cas, tous les hôtes TPU à hôtes multiples devront disposer d'une copie d'un serveur de modèles SAX, et tous les serveurs de modèles fonctionnent comme un groupe de serveurs de modèles SAX pour diffuser le modèle volumineux sur une tranche de TPU multi-hôte.

  1. Créer un cluster SAX

    Vous pouvez suivre la même étape que celle consistant à créer un cluster SAX dans le tutoriel GPT-J pour créer un cluster SAX et un serveur d'administration SAX.

    Si vous disposez déjà d'un cluster SAX, vous pouvez également lancer un serveur de modèle à hôtes multiples dans votre cluster SAX.

  2. Lancer un serveur de modèle SAX à plusieurs hôtes dans un cluster SAX

    Pour créer une tranche de TPU multi-hôte, utilisez la même commande que pour une tranche de TPU à hôte unique. Pour ce faire, spécifiez simplement le type d'accélérateur multi-hôte approprié:

    ACCELERATOR_TYPE=v5litepod-32
    ZONE=us-east1-c
    
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --service-account ${SERVICE_ACCOUNT} \
      --reserved
    

    Pour extraire l'image du serveur de modèles SAX vers tous les hôtes/nœuds de calcul TPU et les lancer, procédez comme suit:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --worker=all \
      --command="
        gcloud auth configure-docker \
          us-docker.pkg.dev
        # Pull sax model server image
        docker pull ${SAX_MODEL_SERVER_IMAGE_URL}
        # Run model server
        docker run \
          --privileged  \
          -it \
          -d \
          --rm \
          --network host \
          --name ${SAX_MODEL_SERVER_DOCKER_NAME} \
          --env SAX_ROOT=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
          ${SAX_MODEL_SERVER_IMAGE_URL} \
            --sax_cell=${SAX_CELL} \
            --port=10001 \
            --platform_chip=tpuv4 \
            --platform_topology=1x1"
    
  3. Publier le modèle sur le cluster SAX

    Cet exemple utilise un modèle LmCloudSpmd175B32Test:

    MODEL_NAME=lmcloudspmd175b32test
    MODEL_CONFIG_PATH=saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd175B32Test
    CHECKPOINT_PATH=None
    REPLICA=1
    

    Pour publier le modèle de test, procédez comme suit:

    docker run \
     ${SAX_UTIL_IMAGE_URL} \
       --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
       publish \
         ${SAX_CELL}/${MODEL_NAME} \
         ${MODEL_CONFIG_PATH} \
         ${CHECKPOINT_PATH} \
         ${REPLICA}
    
  4. Générer des résultats d'inférence

    docker run \
      ${SAX_UTIL_IMAGE_URL} \
        --sax_root=gs://${SAX_ADMIN_STORAGE_BUCKET}/sax-root \
        lm.generate \
          ${SAX_CELL}/${MODEL_NAME} \
          "Q:  Who is Harry Porter's mother? A\: "
    

    Étant donné que cet exemple utilise un modèle de test avec des pondérations aléatoires, le résultat peut ne pas être pertinent.

  5. Effectuer un nettoyage

    Arrêtez les conteneurs Docker:

    docker stop ${SAX_ADMIN_SERVER_DOCKER_NAME}
    docker stop ${SAX_MODEL_SERVER_DOCKER_NAME}
    

    Supprimez le bucket de stockage administrateur Cloud Storage et tout bucket de stockage de données à l'aide de gsutil, comme indiqué ci-dessous.

    gsutil rm -rf gs://${SAX_ADMIN_STORAGE_BUCKET}
    gsutil rm -rf gs://${SAX_DATA_STORAGE_BUCKET}