Saxml と Vertex AI Prediction で TPU を使用して Gemma オープンモデルを提供する

このガイドでは、Saxml と Vertex AI Prediction で Tensor Processing Unit(TPU)を使用して Gemma オープンモデル大規模言語モデル(LLM)を提供する方法について説明します。このガイドでは、2B および 7B パラメータ指示でチューニングした Gemma モデルを Cloud Storage にダウンロードし、TPU で Saxml を実行する Vertex AI Prediction にデプロイします。

背景

Saxml と Vertex AI Prediction で TPU を使用して Gemma を提供します。低レベルのインフラストラクチャに対応し、LLM を費用対効果の高い方法で提供するマネージド AI ソリューションを利用できます。このセクションでは、このチュートリアルで使用されている重要なテクノロジーについて説明します。

Gemma

Gemma は、オープン ライセンスでリリースされ一般公開されている、軽量の生成 AI モデルのセットです。これらの AI モデルは、アプリケーション、ハードウェア、モバイル デバイス、ホスト型サービスで実行できます。Gemma モデルはテキスト生成に使用できますが、特殊なタスク用にチューニングすることもできます。

詳しくは、Gemma のドキュメントをご覧ください。

Saxml

Saxml は、推論に PaxmlJAXPyTorch モデルを提供する試験運用版のシステムです。このチュートリアルでは、Saxml のコスト効率の高い TPU で Gemma を提供する方法について説明します。GPU の設定も同様です。Saxml は、このチュートリアルで使用する Vertex AI Prediction のコンテナをビルドするスクリプトを提供します。

TPU

TPU は、Google が独自に開発した特定用途向け集積回路(ASIC)であり、TensorFlow、PyTorch、JAX などのデータ処理フレームワークを高速化するために使用されます。

このチュートリアルでは、Gemma 2B モデルと Gemma 7B モデルを使用します。Vertex AI Prediction は、次のシングルホスト TPU v5e ノードプールでこれらのモデルをホストします。

  • Gemma 2B: 1 つの TPU チップを表す 1x1 トポロジを持つ TPU v5e ノードプールでホストされます。ノードのマシンタイプは ct5lp-hightpu-1t です。
  • Gemma 7B: 4 つの TPU チップを表す 2x2 トポロジを持つ TPU v5e ノードプールでホストされます。ノードのマシンタイプは ct5lp-hightpu-4t です。

始める前に

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  3. Google Cloud プロジェクトで課金が有効になっていることを確認します

  4. Vertex AI API and Artifact Registry API API を有効にします。

    API を有効にする

  5. Google Cloud Console の [プロジェクト セレクタ] ページで、Google Cloud プロジェクトを選択または作成します。

    プロジェクト セレクタに移動

  6. Google Cloud プロジェクトで課金が有効になっていることを確認します

  7. Vertex AI API and Artifact Registry API API を有効にします。

    API を有効にする

  8. Google Cloud コンソールで、「Cloud Shell をアクティブにする」をクリックします。

    Cloud Shell をアクティブにする

    Google Cloud コンソールの下部で Cloud Shell セッションが開始し、コマンドライン プロンプトが表示されます。Cloud Shell はシェル環境です。Google Cloud CLI がすでにインストールされており、現在のプロジェクトの値もすでに設定されています。セッションが初期化されるまで数秒かかることがあります。

このチュートリアルでは、Cloud Shell を使用して Google Cloud を操作していることを前提としています。Cloud Shell の代わりに別のシェルを使用する場合は、次の追加の構成を行います。

  1. Google Cloud CLI をインストールします。
  2. gcloud CLI を初期化するには:

    gcloud init
  3. Artifact Registry のドキュメントに従って、Docker をインストールします。
  4. Vertex AI Prediction に、5 つの TPU v5e チップに十分な割り当てがあることを確認してください。
  5. Kaggle アカウントを作成します(まだアカウントを保有されていない場合)。

モデルへのアクセス権を取得する

Vertex AI Prediction にデプロイするために Gemma モデルへのアクセス権を取得するには、Kaggle プラットフォームにログインし、ライセンス同意契約に署名して、Kaggle API トークンを入手する必要があります。このチュートリアルでは、Kaggle 認証情報に Kubernetes Secret を使用します。

Gemma を使用するには同意契約に署名する必要があります。手順は次のとおりです。

  1. Kaggle.com のモデルの同意ページにアクセスします。
  2. Kaggle にログインしていない場合はログインします。
  3. [アクセス権限をリクエスト] をクリックします。
  4. [同意に使用するアカウントを選択] セクションで、[Kaggle アカウントを使用して確認] を選択して、同意に Kaggle アカウントを使用します。
  5. モデルの利用規約に同意します。

アクセス トークンを生成する

Kaggle からモデルにアクセスするには、Kaggle API トークンが必要です。

トークンをまだ生成していない場合は、次の手順に沿って生成します。

  1. ブラウザで [Kaggle の設定] に移動します。
  2. [API] セクションで [新しいトークンを作成] をクリックします。

    kaggle.json という名前のファイルがダウンロードされます。

アクセス トークンを Cloud Shell にアップロードする

Cloud Shell で、Kaggle API トークンを Google Cloud プロジェクトにアップロードできます。

  1. Cloud Shell で、 [その他] > [アップロード] をクリックします。
  2. [ファイル] を選択し、[ファイルを選択] をクリックします。
  3. kaggle.json ファイルを開きます。
  4. [アップロード] をクリックします。

Cloud Storage バケットを作成する

モデルのチェックポイントを保存する Cloud Storage バケットを作成する。

Cloud Shell で次のコマンドを実行します。

gcloud storage buckets create gs://CHECKPOINTS_BUCKET_NAME

CHECKPOINTS_BUCKET_NAME は、モデルのチェックポイントを保存する Cloud Storage バケットの名前に置き換えます。

モデルを Cloud Storage バケットにコピーする

Cloud Shell で次のコマンドを実行します。

pip install kaggle --break-system-packages

# For Gemma 2B
mkdir -p /data/gemma_2b-it
kaggle models instances versions download google/gemma/pax/2b-it/1 --untar -p /data/gemma_2b-it
gsutil -m cp -R /data/gemma_2b-it/* gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/

# For Gemma 7B
mkdir -p /data/gemma_7b-it
kaggle models instances versions download google/gemma/pax/7b-it/1 --untar -p /data/gemma_7b-it
gsutil -m cp -R /data/gemma_7b-it/* gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/

Artifact Registry リポジトリを作成する

Artifact Registry リポジトリを作成して、次のセクションで作成するコンテナ イメージを保存します。

プロジェクトで Artifact Registry API サービスを有効にします。

gcloud services enable artifactregistry.googleapis.com

シェルで次のコマンドを実行して、Artifact Registry リポジトリを作成します。

gcloud artifacts repositories create saxml \
 --repository-format=docker \
 --location=LOCATION \
 --description="Saxml Docker repository"

LOCATION は、Artifact Registry がコンテナ イメージを保存するリージョンに置き換えます。後で、このリージョンと一致するリージョン エンドポイントに Vertex AI モデルリソースを作成する必要があります。そのため、Vertex AI にリージョン エンドポイントがあるリージョンを選択してください(例: TPU の us-west1)。

コンテナ イメージを Artifact Registry に push する

ビルド済みの Saxml コンテナは us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest で入手できます。これを Artifact Registry にコピーします。Artifact Registry にアクセスできるように Docker を構成します。次に、コンテナ イメージを Artifact Registry リポジトリに push します。

  1. ローカルの Docker インストール権限を選択したリージョンの Artifact Registry に push するには、シェルで次のコマンドを実行します。

    gcloud auth configure-docker LOCATION-docker.pkg.dev
    
    • LOCATION は、リポジトリを作成したリージョンに置き換えます。
  2. ビルドしたコンテナ イメージを Artifact Registry にコピーするには、シェルで次のコマンドを実行します。

    docker tag us-docker.pkg.dev/vertex-ai/prediction/sax-tpu:latest LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest
    
  3. ビルドしたコンテナ イメージを Artifact Registry に push するには、シェルで次のコマンドを実行します。

    docker push LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest
    

    前のセクションと同様に、次のよう置き換えます。

モデルのデプロイ

モデルをアップロードする

Saxml コンテナを使用する Model リソースをアップロードするには、次の gcloud ai models upload コマンドを実行します。

Gemma 2B-it

gcloud ai models upload \
  --region=LOCATION \
  --display-name=DEPLOYED_MODEL_NAME \
  --container-image-uri=LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest \
  --artifact-uri='gs://CHECKPOINTS_BUCKET_NAME/gemma_2b-it/' \
  --container-args='--model_path=saxml.server.pax.lm.params.gemma.Gemma2BFP16' \
  --container-args='--platform_chip=tpuv5e' \
  --container-args='--platform_topology=2x2' \
  --container-args='--ckpt_path_suffix=checkpoint_00000000' \
  --container-ports=8502

Gemma 7B-it

gcloud ai models upload \
  --region=LOCATION \
  --display-name=DEPLOYED_MODEL_NAME \
  --container-image-uri=LOCATION-docker.pkg.dev/PROJECT_ID/saxml/saxml-tpu:latest \
  --artifact-uri='gs://CHECKPOINTS_BUCKET_NAME/gemma_7b-it/' \
  --container-args='--model_path=saxml.server.pax.lm.params.gemma.Gemma7BFP16' \
  --container-args='--platform_chip=tpuv5e' \
  --container-args='--platform_topology=2x2' \
  --container-args='--ckpt_path_suffix=checkpoint_00000000' \
  --container-ports=8502

次のように置き換えます。

  • PROJECT_ID: Google Cloud プロジェクトの ID
  • LOCATION_ID: Vertex AI を使用するリージョン。TPU は us-west1 でのみ使用できます。
  • DEPLOYED_MODEL_NAME: DeployedModel の名前。DeployedModelModel の表示名を使用することもできます。

エンドポイントを作成する

モデルを使用してオンライン予測を行う前に、モデルをエンドポイントにデプロイする必要があります。既存のエンドポイントにモデルをデプロイする場合は、この手順を省略できます。次の例では、gcloud ai endpoints create コマンドを使用します。

gcloud ai endpoints create \
  --region=LOCATION \
  --display-name=ENDPOINT_NAME

次のように置き換えます。

  • LOCATION_ID: Vertex AI を使用するリージョン。
  • ENDPOINT_NAME: エンドポイントの表示名。

Google Cloud CLI ツールがエンドポイントを作成するまでに数秒かかる場合があります。

モデルをエンドポイントにデプロイする

エンドポイントの準備が整ったら、モデルをエンドポイントにデプロイします。

ENDPOINT_ID=$(gcloud ai endpoints list \
   --region=LOCATION \
   --filter=display_name=ENDPOINT_NAME \
   --format="value(name)")

MODEL_ID=$(gcloud ai models list \
   --region=LOCATION \
   --filter=display_name=DEPLOYED_MODEL_NAME \
   --format="value(name)")

gcloud ai endpoints deploy-model $ENDPOINT_ID \
  --region=LOCATION \
  --model=$MODEL_ID \
  --display-name=DEPLOYED_MODEL_NAME \
  --machine-type=ct5lp-hightpu-4t \
  --traffic-split=0=100

次のように置き換えます。

  • LOCATION_ID: Vertex AI を使用するリージョン。
  • ENDPOINT_NAME: エンドポイントの表示名。
  • DEPLOYED_MODEL_NAME: DeployedModel の名前。DeployedModelModel の表示名を使用することもできます。

Gemma 2B はより小さい ct5lp-hightpu-1t マシンにデプロイできます。この場合、モデルをアップロードするときに --platform_topology=1x1 を指定する必要があります。

Google Cloud CLI ツールでエンドポイントにモデルをデプロイする場合、処理に数分かかることがあります。モデルが正常にデプロイされると、このコマンドは次の出力を返します。

  Deployed a model to the endpoint xxxxx. Id of the deployed model: xxxxx.

デプロイされたモデルからオンライン予測を提供する

Vertex AI Prediction エンドポイントからモデルを呼び出すには、標準の推論リクエストの JSON オブジェクトを使用して予測リクエストをフォーマットします。

次の例では、gcloud ai endpoints predict コマンドを使用します。

ENDPOINT_ID=$(gcloud ai endpoints list \
   --region=LOCATION \
   --filter=display_name=ENDPOINT_NAME \
   --format="value(name)")

gcloud ai endpoints predict $ENDPOINT_ID \
  --region=LOCATION \
  --http-headers=Content-Type=application/json \
  --json-request instances.json

次のように置き換えます。

  • LOCATION_ID: Vertex AI を使用するリージョン。
  • ENDPOINT_NAME: エンドポイントの表示名。
  • instances.json の形式は {"instances": [{"text_batch": "<your prompt>"},{...}]} です。

クリーンアップ

Vertex AI の料金Artifact Registry の料金が発生しないように、このチュートリアルで作成した Google Cloud リソースを削除します。

  1. エンドポイントからモデルのデプロイを解除してエンドポイントを削除するには、シェルで次のコマンドを実行します。

    ENDPOINT_ID=$(gcloud ai endpoints list \
       --region=LOCATION \
       --filter=display_name=ENDPOINT_NAME \
       --format="value(name)")
    
    DEPLOYED_MODEL_ID=$(gcloud ai endpoints describe $ENDPOINT_ID \
       --region=LOCATION \
       --format="value(deployedModels.id)")
    
    gcloud ai endpoints undeploy-model $ENDPOINT_ID \
      --region=LOCATION \
      --deployed-model-id=$DEPLOYED_MODEL_ID
    
    gcloud ai endpoints delete $ENDPOINT_ID \
       --region=LOCATION \
       --quiet
    

    LOCATION は、前のセクションでモデルを作成したリージョンに置き換えます。

  2. モデルを削除するには、シェルで次のコマンドを実行します。

    MODEL_ID=$(gcloud ai models list \
       --region=LOCATION \
       --filter=display_name=DEPLOYED_MODEL_NAME \
       --format="value(name)")
    
    gcloud ai models delete $MODEL_ID \
       --region=LOCATION \
       --quiet
    

    LOCATION は、前のセクションでモデルを作成したリージョンに置き換えます。

  3. Artifact Registry リポジトリとその中のコンテナ イメージを削除するには、シェルで次のコマンドを実行します。

    gcloud artifacts repositories delete saxml \
      --location=LOCATION \
      --quiet
    

    LOCATION は、前のセクションで Artifact Registry リポジトリを作成したリージョンに置き換えます。

制限事項

  • Vertex AI Prediction では、Cloud TPU は us-west1 でのみサポートされています。詳細については、リージョンをご覧ください。

次のステップ

  • Llama2 や GPT-J などの Saxml モデルをデプロイする方法について確認する。