Présentation de Trillium (v6e)

Dans cette documentation, l'API TPU et les journaux, v6e est utilisé pour désigner Trillium. v6e représente la 6e génération de TPU de Google.

Avec 256 puces par pod, la version v6e présente de nombreuses similitudes avec la version v5e. Ce système est optimisé pour être le produit le plus intéressant pour l'entraînement, l'ajustement et le traitement des transformateurs, du texte en image et des réseaux de neurones convolutifs (CNN).

Architecture du système v6e

Pour en savoir plus sur la configuration de Cloud TPU, consultez la documentation sur la version v6e.

Ce document se concentre sur le processus de configuration de l'entraînement de modèle à l'aide des frameworks JAX, PyTorch ou TensorFlow. Avec chaque framework, vous pouvez provisionner des TPU à l'aide de ressources mises en file d'attente ou de Google Kubernetes Engine (GKE). La configuration de GKE peut être effectuée à l'aide de XPK ou de commandes GKE.

Préparer un projet Google Cloud

  1. Connectez-vous à votre compte Google. Si vous ne l'avez pas déjà fait, créez un compte.
  2. Dans la console Google Cloud, sélectionnez ou créez un projet Cloud à partir de la page de sélection du projet.
  3. Activez la facturation pour votre projet Google Cloud. La facturation est obligatoire pour toute utilisation de Google Cloud.
  4. Installez les composants gcloud alpha.
  5. Exécutez la commande suivante pour installer la dernière version des composants gcloud.

    gcloud components update
    
  6. Activez l'API TPU à l'aide de la commande gcloud suivante dans Cloud Shell. Vous pouvez également l'activer à partir de la console Google Cloud.

    gcloud services enable tpu.googleapis.com
    
  7. Activer les autorisations avec le compte de service TPU pour l'API Compute Engine

    Les comptes de service permettent au service Cloud TPU d'accéder à d'autres services Google Cloud. Un compte de service géré par l'utilisateur est une bonne pratique Google Cloud. Suivez ces guides pour créer et accorder des rôles. Les rôles suivants sont nécessaires:

    • Administrateur TPU
    • Administrateur de l'espace de stockage
    • Rédacteur de journaux
    • Rédacteur de métriques Monitoring

    a. Configurez les autorisations XPK avec votre compte utilisateur pour GKE : XPK.

  8. Créez des variables d'environnement pour l'ID et la zone du projet.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Créez une identité de service pour la VM TPU.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Sécuriser la capacité

Contactez l'équipe commerciale/compte d'assistance Cloud TPU pour demander un quota TPU et répondre à vos questions sur la capacité.

Provisionner l'environnement Cloud TPU

Les TPU v6e peuvent être provisionnés et gérés avec GKE, avec GKE et XPK (un outil de CLI wrapper sur GKE), ou en tant que ressources mises en file d'attente.

Prérequis

  • Vérifiez que votre projet dispose d'un quota TPUS_PER_TPU_FAMILY suffisant, qui spécifie le nombre maximal de puces auxquelles vous pouvez accéder dans votre projet Google Cloud.
  • La version v6e a été testée avec la configuration suivante :
    • Python 3.10 ou version ultérieure
    • Versions logicielles nocturnes :
      • JAX 0.4.32.dev20240912 par nuit
      • LibTPU 0.1.dev20240912+nightly nightly
    • Versions logicielles stables :
      • JAX + JAX Lib de la version 0.4.35
  • Vérifiez que votre projet dispose d'un quota TPU suffisant pour :
    • Quota de VM TPU
    • Quota d'adresses IP
    • Quota Hyperdisk-balance
  • Autorisations de projet utilisateur

Variables d'environnement

Dans Cloud Shell, créez les variables d'environnement suivantes:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Description des options de commande

Variable Description
NODE_ID ID attribué par l'utilisateur du TPU créé lors de l'allocation de la requête de ressources mise en file d'attente.
PROJECT_ID Nom du projet Google Cloud Utilisez un projet existant ou créez-en un sur .
ZONE Pour connaître les zones compatibles, consultez le document Régions et zones TPU.
ACCELERATOR_TYPE Consultez la section Types d'accélérateurs.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Il s'agit de l'adresse e-mail de votre compte de service, que vous pouvez trouver dans Google Cloud Console -> IAM -> Comptes de service.

Exemple: tpu-service-account@<votre_ID_de_projet>.iam.gserviceaccount.com.com

NUM_SLICES Nombre de tranches à créer (nécessaire pour Multislice uniquement)
QUEUED_RESOURCE_ID ID de texte attribué par l'utilisateur à la requête de ressource mise en file d'attente.
VALID_DURATION Durée de validité de la requête de ressource mise en file d'attente.
NETWORK_NAME Nom d'un réseau secondaire à utiliser.
NETWORK_FW_NAME Nom d'un pare-feu réseau secondaire à utiliser.

Optimisations des performances réseau

Pour de meilleures performances,utilisez un réseau avec une MTU (unité de transmission maximale) de 8 896.

Par défaut, un cloud privé virtuel (VPC) ne fournit qu'une MTU de 1 460 octets,ce qui offre des performances réseau non optimales. Vous pouvez définir la MTU d'un réseau VPC sur n'importe quelle valeur comprise entre 1 300 octets et 8 896 octets (inclus). Les tailles de MTU personnalisées les plus courantes sont de 1 500 octets (Ethernet standard) ou de 8 896 octets (maximum possible). Pour en savoir plus, consultez la section Tailles de MTU valides pour le réseau VPC.

Pour en savoir plus sur la modification du paramètre MTU d'un réseau existant ou par défaut, consultez la page Modifier le paramètre de MTU d'un réseau VPC.

L'exemple suivant crée un réseau avec 8 896 MTU.

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

Utilisation de plusieurs cartes d'interface réseau (option pour Multislice)

Les variables d'environnement suivantes sont nécessaires pour un sous-réseau secondaire lorsque vous utilisez un environnement multicouche.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Utilisez les commandes suivantes pour créer un routage IP personnalisé pour le réseau et le sous-réseau.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

Une fois qu'une tranche multiréseau a été créée, vous pouvez vérifier que les deux NIC sont utilisées en exécutant --command ifconfig dans le cadre de la charge de travail XPK. Ensuite, examinez la sortie imprimée de cette charge de travail XPK dans les journaux de la console Cloud et vérifiez que eth0 et eth1 ont mtu=8896.

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Vérifiez que eth0 et eth1 ont mtu=8 896. Pour vérifier que le multi-NIC est en cours d'exécution,exécutez la commande --command "ifconfig" dans le cadre de la charge de travail XPK. Examinez ensuite la sortie imprimée de cette charge de travail xpk dans les journaux de la console Cloud, et vérifiez que eth0 et eth1 ont mtu=8896.

Amélioration des paramètres TCP

Pour les TPU créés à l'aide de l'interface de ressources en file d'attente, vous pouvez exécuter la commande suivante pour améliorer les performances réseau en modifiant les paramètres TCP par défaut pour rto_min et quickack.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

Provisionnement avec des ressources en file d'attente (API Cloud TPU)

La capacité peut être provisionnée à l'aide de la commande create de la ressource en file d'attente.

  1. Créez une requête de ressource en file d'attente TPU.

    L'indicateur --reserved n'est nécessaire que pour les ressources réservées, et non pour les ressources à la demande.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    Si la requête de ressources mise en file d'attente est créée, l'état du champ "response" est "WAITING_FOR_RESOURCES" ou "FAILED". Si la requête de ressource mise en file d'attente est à l'état "WAITING_FOR_RESOURCES", la ressource mise en file d'attente a été mise en file d'attente et sera provisionnée lorsqu'il y aura suffisamment de capacité TPU. Si la requête de ressource mise en file d'attente est à l'état "FAILED" (ÉCHEC), le motif de l'échec s'affiche dans la sortie. La requête de ressource mise en file d'attente expire si un v6e n'est pas provisionné dans la durée spécifiée, et son état devient "FAILED" (ÉCHEC). Pour en savoir plus, consultez la documentation publique sur les ressources mises en file d'attente.

    Lorsque votre demande de ressources en file d'attente est à l'état "ACTIVE", vous pouvez vous connecter à vos VM TPU à l'aide de SSH. Utilisez les commandes list ou describe pour interroger l'état de votre ressource mise en file d'attente.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Lorsque la ressource mise en file d'attente est dans l'état "ACTIVE", le résultat ressemble à ce qui suit:

      state:
       state: ACTIVE
    
  2. Gérez vos VM TPU. Pour connaître les options de gestion de vos VM TPU, consultez la section Gérer les VM TPU.

  3. Se connecter à vos VM TPU à l'aide de SSH

    Vous pouvez installer des binaires sur chaque VM TPU de votre tranche TPU et exécuter du code. Consultez la section Types de VM pour déterminer le nombre de VM de votre tranche.

    Pour installer les binaires ou exécuter du code, vous pouvez utiliser SSH pour vous connecter à une VM à l'aide de la commande tpu-vm ssh.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Pour vous connecter à une VM spécifique via SSH, utilisez l'indicateur --worker, qui suit un indice basé sur 0:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Si les formes de tranche contiennent plus de huit chips, vous aurez plusieurs VM dans une même tranche. Dans ce cas, utilisez les paramètres --worker=all et --command dans votre commande gcloud alpha compute tpus tpu-vm ssh pour exécuter une commande sur toutes les VM simultanément. Exemple :

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Supprimer une ressource en file d'attente

    Supprimez une ressource en file d'attente à la fin de la session ou supprimez les requêtes de ressources en file d'attente qui sont à l'état "FAILED". Pour supprimer une ressource en file d'attente, supprimez la tranche, puis la requête de ressource en file d'attente en deux étapes:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

Utiliser GKE avec v6e

Si vous utilisez des commandes GKE avec la version v6e, vous pouvez utiliser des commandes Kubernetes ou XPK pour provisionner des TPU et entraîner ou diffuser des modèles. Consultez Planifier des TPU dans GKE pour découvrir comment utiliser GKE avec des TPU et v6e.

Configuration du framework

Cette section décrit le processus de configuration général pour l'entraînement de modèles de ML à l'aide des frameworks JAX, PyTorch ou TensorFlow. Vous pouvez provisionner des TPU à l'aide de ressources en file d'attente ou de GKE. La configuration de GKE peut être effectuée à l'aide de commandes XPK ou Kubernetes.

Configurer JAX à l'aide de ressources en file d'attente

Installez JAX sur toutes les VM TPU de votre ou vos tranches simultanément à l'aide de gcloud alpha compute tpus tpu-vm ssh. Pour Multislice, ajoutez --node=all.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Vous pouvez exécuter le code Python suivant pour vérifier le nombre de cœurs TPU disponibles dans votre tranche et pour vérifier que tout est correctement installé (les sorties affichées ici ont été produites avec une tranche v6e-16):

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

Le résultat ressemble à ce qui suit :

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() indique le nombre total de puces dans la tranche donnée. jax.local_device_count() indique le nombre de puces accessibles par une seule VM dans cette tranche.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

Résoudre les problèmes de configuration JAX

En règle générale, nous vous recommandons d'activer la journalisation détaillée dans le fichier manifeste de votre charge de travail GKE. Fournissez ensuite les journaux à l'assistance GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Messages d'erreur

no endpoints available for service 'jobset-webhook-service'

Cette erreur signifie que le jobset n'a pas été installé correctement. Vérifiez si les pods Kubernetes de déploiement jobset-controller-manager sont en cours d'exécution. Pour en savoir plus, consultez la documentation de dépannage JobSet.

TPU initialization failed: Failed to connect

Assurez-vous que votre nœud GKE utilise la version 1.30.4-gke.1348000 ou une version ultérieure (GKE 1.31 n'est pas compatible).

Configuration pour PyTorch

Cette section explique comment commencer à utiliser PJRT sur la version v6e avec PyTorch/XLA. La version recommandée est Python 3.10.

Configurer PyTorch à l'aide de GKE avec XPK

Vous pouvez utiliser le conteneur Docker suivant avec XPK, qui contient déjà les dépendances PyTorch:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Pour créer une charge de travail XPK, utilisez la commande suivante:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

L'utilisation de --base-docker-image crée une nouvelle image Docker avec le répertoire de travail actuel intégré au nouveau Docker.

Configurer PyTorch à l'aide de ressources en file d'attente

Suivez ces étapes pour installer PyTorch à l'aide de ressources mises en file d'attente et exécuter un petit script sur la version 6e.

Installez les dépendances à l'aide de SSH pour accéder aux VM.

Pour Multislice, ajoutez --node=all:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Améliorer les performances des modèles avec des allocations importantes et fréquentes

Pour les modèles comportant des allocations fréquentes et importantes, nous avons constaté que l'utilisation de tcmalloc améliore considérablement les performances par rapport à l'implémentation par défaut de malloc. Par conséquent, la valeur par défaut de malloc utilisée sur la VM TPU est tcmalloc. Toutefois, en fonction de votre charge de travail (par exemple, avec DLRM qui dispose d'allocations très importantes pour ses tables d'imbrication), tcmalloc peut entraîner un ralentissement. Dans ce cas, vous pouvez essayer de réinitialiser la variable suivante à l'aide de malloc par défaut:

unset LD_PRELOAD

Utilisez un script Python pour effectuer un calcul sur la VM v6e:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Un résultat semblable à celui-ci doit s'afficher :

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configuration pour TensorFlow

Pour la version Preview publique de la version 6e, seule la version d'exécution tf-nightly est compatible.

Vous pouvez réinitialiser tpu-runtime avec la version TensorFlow compatible avec la version v6e en exécutant les commandes suivantes:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Utilisez SSH pour accéder à worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Installez TensorFlow sur worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exportez la variable d'environnement TPU_NAME:

export TPU_NAME=v6e-16

Vous pouvez exécuter le script Python suivant pour vérifier le nombre de cœurs TPU disponibles dans votre tranche et pour vérifier que tout est correctement installé (les sorties affichées ont été produites avec une tranche v6e-16):

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

Le résultat ressemble à ce qui suit :

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e avec SkyPilot

Vous pouvez utiliser TPU v6e avec SkyPilot. Suivez les étapes ci-dessous pour ajouter des informations sur la localisation/la tarification liées à la v6e à SkyPilot.

  1. Ajoutez les éléments suivants à la fin de ~/.sky/catalogs/v5/gcp/vms.csv :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Spécifiez les ressources suivantes dans un fichier YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Lancez un cluster avec TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Connectez-vous au TPU v6e à l'aide de SSH: ssh tpu_v6

Tutoriels sur l'inférence

Les sections suivantes fournissent des tutoriels sur la diffusion de modèles MaxText et PyTorch à l'aide de JetStream, ainsi que sur la diffusion de modèles MaxDiffusion sur TPU v6e.

MaxText sur JetStream

Ce tutoriel explique comment utiliser JetStream pour diffuser des modèles MaxText (JAX) sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU). Dans ce tutoriel, vous allez exécuter le benchmark d'inférence pour le modèle Llama2-7B.

Avant de commencer

  1. Créez un TPU v6e avec quatre puces:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connectez-vous au TPU à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Exécuter le tutoriel

Pour configurer JetStream et MaxText, convertir les points de contrôle du modèle et exécuter le benchmark d'inférence, suivez les instructions du dépôt GitHub.

Effectuer un nettoyage

Supprimez le TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

vLLM sur PyTorch TPU

Vous trouverez ci-dessous un tutoriel simple qui vous explique comment commencer à utiliser vLLM sur une VM TPU. Pour notre exemple de bonnes pratiques de déploiement de vLLM sur Trillium en production, nous publierons un guide de l'utilisateur GKE dans les prochains jours (restez à l'écoute).

Avant de commencer

  1. Créez un TPU v6e avec quatre puces:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    Description des options de commande

    Variable Description
    NODE_ID ID attribué par l'utilisateur du TPU créé lors de l'allocation de la requête de ressource mise en file d'attente.
    PROJECT_ID Nom du projet Google Cloud Utilisez un projet existant ou créez-en un sur .
    ZONE Pour connaître les zones compatibles, consultez le document Régions et zones TPU.
    ACCELERATOR_TYPE Consultez la section Types d'accélérateurs.
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT Il s'agit de l'adresse e-mail de votre compte de service, que vous pouvez trouver dans Google Cloud Console -> IAM -> Comptes de service.

    Exemple: tpu-service-account@<votre_ID_de_projet>.iam.gserviceaccount.com.com

  2. Connectez-vous au TPU à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

Configurer vLLM sur un TPU

  1. Clonez le dépôt vLLM et accédez au répertoire vLLM:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. Nettoyez les packages torch et torch-xla existants:

    pip uninstall torch torch-xla -y
    
  3. Installez PyTorch et PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. Installez JAX et Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. Installez les autres dépendances de compilation:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

Accéder au modèle

Vous devez signer le contrat de consentement pour utiliser la famille de modèles Llama3 dans le dépôt HuggingFace.

Générez un nouveau jeton Hugging Face si vous n'en possédez pas déjà un :

  1. Cliquez sur Your Profile > Settings > Access Tokens (Votre profil > Paramètres > Jetons d'accès).
  2. Sélectionnez New Token (Nouveau jeton).
  3. Spécifiez le nom de votre choix et un rôle d'au moins Read.
  4. Sélectionnez Générer un jeton.
  5. Copiez le jeton généré dans votre presse-papiers, définissez-le en tant que variable d'environnement et authentifiez-vous avec huggingface-cli:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

Télécharger les données de benchmark

  1. Créez un répertoire /data et téléchargez l'ensemble de données ShareGPT depuis Hugging Face.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

Lancer le serveur vLLM

La commande suivante télécharge les poids du modèle à partir du hub de modèles Hugging Face dans le répertoire /tmp de la VM TPU, précompile une plage de formes d'entrée et écrit la compilation du modèle dans ~/.cache/vllm/xla_cache.

Pour en savoir plus, consultez la documentation sur vLLM.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

Exécuter des benchmarks vLLM

Exécutez le script de benchmark vLLM:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

Effectuer un nettoyage

Supprimez le TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch sur JetStream

Ce tutoriel explique comment utiliser JetStream pour diffuser des modèles PyTorch sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU). Dans ce tutoriel, vous allez exécuter le benchmark d'inférence pour le modèle Llama2-7B.

Avant de commencer

  1. Créez un TPU v6e avec quatre puces:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connectez-vous au TPU à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Exécuter le tutoriel

Pour configurer JetStream-PyTorch, convertir les points de contrôle du modèle et exécuter le benchmark d'inférence, suivez les instructions du dépôt GitHub.

Effectuer un nettoyage

Supprimez le TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

Inférence MaxDiffusion

Ce tutoriel explique comment diffuser des modèles MaxDiffusion sur TPU v6e. Dans ce tutoriel, vous allez générer des images à l'aide du modèle Stable Diffusion XL.

Avant de commencer

  1. Créez un TPU v6e avec quatre puces:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Connectez-vous au TPU à l'aide de SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Créer un environnement Conda

  1. Créez un répertoire pour Miniconda:

    mkdir -p ~/miniconda3
  2. Téléchargez le script d'installation de Miniconda:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Installez Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Supprimez le script d'installation Miniconda:

    rm -rf ~/miniconda3/miniconda.sh
  5. Ajoutez Miniconda à votre variable PATH:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Actualisez ~/.bashrc pour appliquer les modifications à la variable PATH:

    source ~/.bashrc
  7. Créez un environnement Conda:

    conda create -n tpu python=3.10
  8. Activez l'environnement Conda:

    source activate tpu

Configurer MaxDiffusion

  1. Clonez le dépôt MaxDiffusion et accédez au répertoire MaxDiffusion:

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Passez à la branche mlperf-4.1:

    git checkout mlperf4.1
  3. Installez MaxDiffusion:

    pip install -e .
  4. Installez les dépendances :

    pip install -r requirements.txt
  5. Installez JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Générer des images

  1. Définissez des variables d'environnement pour configurer l'environnement d'exécution TPU:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Générez des images à l'aide de la requête et des configurations définies dans src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

Effectuer un nettoyage

Supprimez le TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

Tutoriels de formation

Les sections suivantes fournissent des tutoriels pour entraîner MaxText.

Modèles MaxDiffusion et PyTorch sur TPU v6e.

MaxText et MaxDiffusion

Les sections suivantes couvrent le cycle de vie de l'entraînement des modèles MaxText et MaxDiffusion.

En général, les étapes de haut niveau sont les suivantes:

  1. Créez l'image de base de la charge de travail.
  2. Exécutez votre charge de travail à l'aide de XPK.
    1. Créez la commande d'entraînement pour la charge de travail.
    2. Déployez la charge de travail.
  3. Suivez la charge de travail et affichez les métriques.
  4. Supprimez la charge de travail XPK si elle n'est pas nécessaire.
  5. Supprimez le cluster XPK lorsqu'il n'est plus nécessaire.

Créer une image de base

Installez MaxText ou MaxDiffusion, puis créez l'image Docker:

  1. Clonez le dépôt que vous souhaitez utiliser et accédez au répertoire du dépôt:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configurez Docker pour qu'il utilise la Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Créez l'image Docker à l'aide de la commande suivante ou à l'aide de la pile stable JAX. Pour en savoir plus sur la pile stable JAX, consultez Créer une image Docker avec la pile stable JAX.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Si vous lancez la charge de travail à partir d'une machine sur laquelle l'image n'est pas compilée localement, importez-la:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Créer une image Docker avec la pile stable JAX

Vous pouvez créer les images Docker MaxText et MaxDiffusion à l'aide de l'image de base de la pile stable JAX.

La pile stable JAX fournit un environnement cohérent pour MaxText et MaxDiffusion en regroupant JAX avec des paquets de base tels que orbax, flax et optax, ainsi qu'une libtpu.so bien qualifiée qui gère les utilitaires de programme TPU et d'autres outils essentiels. Ces bibliothèques sont testées pour garantir la compatibilité, ce qui fournit une base stable pour la compilation et l'exécution de MaxText et de MaxDiffusion, et élimine les conflits potentiels dus à des versions de packages incompatibles.

La pile stable JAX inclut une libtpu.so entièrement publiée et qualifiée, la bibliothèque principale qui gère la compilation, l'exécution et la configuration du réseau ICI des programmes TPU. La version libtpu remplace le build quotidien précédemment utilisé par JAX et garantit la fonctionnalité cohérente des calculs XLA sur le TPU avec des tests de qualification au niveau PJRT dans les IR HLO/StableHLO.

Pour créer l'image Docker MaxText et MaxDiffusion avec la pile stable JAX, lorsque vous exécutez le script docker_build_dependency_image.sh, définissez la variable MODE sur stable_stack et la variable BASEIMAGE sur l'image de base que vous souhaitez utiliser.

L'exemple suivant spécifie us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 comme image de base:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Pour obtenir la liste des images de base de la pile stable JAX disponibles, consultez la section Images de la pile stable JAX dans Artifact Registry.

Exécuter votre charge de travail à l'aide de XPK

  1. Définissez les variables d'environnement suivantes si vous n'utilisez pas les valeurs par défaut définies par MaxText ou MaxDiffusion:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. Créez votre script de modèle à copier en tant que commande d'entraînement à l'étape suivante. N'exécutez pas encore le script du modèle.

    MaxText

    MaxText est un LLM Open Source hautes performances et hautement évolutif, écrit en Python et JAX purs, et ciblant les TPU et GPU Google Cloud pour l'entraînement et l'inférence.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    Gemma est une famille de grands modèles de langage (LLM) à poids ouverts développés par Google DeepMind, basés sur la recherche et la technologie Gemini.

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral est un modèle d'IA de pointe développé par Mistral AI, qui utilise une architecture MoE (Mixture of Experts) sporadique.

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama est une famille de grands modèles de langage (LLM) à pondération ouverte développés par Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion est un ensemble d'implémentations de référence de divers modèles de diffusion latente écrits en Python et JAX purs, qui s'exécutent sur des appareils XLA, y compris des Cloud TPU et des GPU. Stable Diffusion est un modèle de texte vers image latent qui génère des images photoréalistes à partir de n'importe quelle entrée textuelle.

    Vous devez installer une branche spécifique pour exécuter MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Script d'entraînement:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. Exécutez le modèle à l'aide du script que vous avez créé à l'étape précédente. Vous devez spécifier l'indicateur --base-docker-image pour utiliser l'image de base MaxText ou l'indicateur --docker-image et l'image que vous souhaitez utiliser.

    Facultatif: vous pouvez activer la journalisation de débogage en incluant l'option --enable-debug-logs. Pour en savoir plus, consultez Déboguer JAX sur MaxText.

    Facultatif: vous pouvez créer un Vertex AI Experiment pour importer des données dans Vertex AI TensorBoard en incluant l'indicateur --use-vertex-tensorboard. Pour en savoir plus, consultez Surveiller JAX sur MaxText à l'aide de Vertex AI.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    Remplacez les variables suivantes :

    • CLUSTER_NAME: nom de votre cluster XPK.
    • ACCELERATOR_TYPE: version et taille de votre TPU. Exemple :v6e-256
    • NUM_SLICES: nombre de tranches TPU.
    • YOUR_MODEL_SCRIPT: script du modèle à exécuter en tant que commande d'entraînement.

    Le résultat inclut un lien permettant de suivre votre charge de travail, semblable à celui-ci:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Ouvrez le lien, puis cliquez sur l'onglet Journaux pour suivre votre charge de travail en temps réel.

Déboguer JAX sur MaxText

Utilisez des commandes XPK supplémentaires pour diagnostiquer pourquoi le cluster ou la charge de travail ne s'exécute pas:

Surveiller JAX sur MaxText à l'aide de Vertex AI

Afficher les données scalaires et de profil via TensorBoard géré par Vertex AI.

  1. Augmentez le nombre de requêtes de gestion des ressources (CRUD) pour la zone que vous utilisez de 600 à 5 000. Cela peut ne pas poser de problème pour les petites charges de travail utilisant moins de 16 VM.
  2. Installez des dépendances telles que cloud-accelerator-diagnostics pour Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Créez votre cluster XPK à l'aide de l'option --create-vertex-tensorboard, comme décrit dans Créer Vertex AI TensorBoard. Vous pouvez également exécuter cette commande sur des clusters existants.

  4. Créez votre test Vertex AI lorsque vous exécutez votre charge de travail XPK à l'aide de l'indicateur --use-vertex-tensorboard et de l'indicateur --experiment-name facultatif. Pour obtenir la liste complète des étapes, consultez Créer un test Vertex AI pour importer des données dans Vertex AI TensorBoard.

Les journaux incluent un lien vers un Vertex AI TensorBoard, semblable à celui-ci:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

Vous pouvez également trouver le lien Vertex AI TensorBoard dans la console Google Cloud. Accédez à Tests Vertex AI dans la console Google Cloud. Sélectionnez la région appropriée dans le menu déroulant.

Le répertoire TensorBoard est également écrit dans le bucket Cloud Storage que vous avez spécifié avec ${BASE_OUTPUT_DIR}.

Supprimer des charges de travail XPK

Utilisez la commande xpk workload delete pour supprimer une ou plusieurs charges de travail en fonction du préfixe ou de l'état de la tâche. Cette commande peut être utile si vous avez envoyé des charges de travail XPK qui n'ont plus besoin d'être exécutées ou si des tâches sont bloquées dans la file d'attente.

Supprimer le cluster XPK

Exécutez la commande xpk cluster delete pour supprimer un cluster:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama et PyTorch

Ce tutoriel explique comment entraîner des modèles Llama à l'aide de PyTorch/XLA sur TPU v6e à l'aide de l'ensemble de données WikiText. En outre, les utilisateurs peuvent accéder aux descriptions de modèles TPU PyTorch en tant qu'images Docker ici.

Installation

Installez la fourchette pytorch-tpu/transformers de Hugging Face Transformers et les dépendances dans un environnement virtuel:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

Configurer les configurations de modèle

La commande d'entraînement de la section suivante, Créer votre script de modèle, utilise deux fichiers de configuration JSON pour définir les paramètres du modèle et la configuration FSDP (Fully Sharded Data Parallel). Le fractionnement FSDP permet aux poids du modèle de s'adapter à une taille de lot plus importante lors de l'entraînement. Lorsque vous entraînez des modèles plus petits, il peut suffire d'utiliser le parallélisme des données et de répliquer les poids sur chaque appareil. Pour en savoir plus sur le fractionnement des tenseurs sur plusieurs appareils dans PyTorch/XLA, consultez le guide de l'utilisateur de SPMD PyTorch/XLA.

  1. Créez le fichier de configuration des paramètres du modèle. Voici la configuration des paramètres du modèle pour Llama3-8B. Pour les autres modèles, recherchez la configuration sur Hugging Face. Par exemple, consultez la configuration Llama2-7B.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. Créez le fichier de configuration FSDP:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    Pour en savoir plus sur le FSDP, consultez FSDPv2.

  3. Importez les fichiers de configuration dans vos VM TPU à l'aide de la commande suivante:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    Vous pouvez également créer les fichiers de configuration dans votre répertoire de travail actuel et utiliser l'indicateur --base-docker-image dans XPK.

Créer le script de votre modèle

Créez votre script de modèle, en spécifiant le fichier de configuration des paramètres du modèle à l'aide de l'indicateur --config_name et le fichier de configuration FSDP à l'aide de l'indicateur --fsdp_config. Vous exécuterez ce script sur votre TPU dans la section suivante, Exécuter le modèle. N'exécutez pas encore le script du modèle.

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

Exécuter le modèle

Exécutez le modèle à l'aide du script que vous avez créé à l'étape précédente, Créer le script de votre modèle.

Si vous utilisez une VM TPU à hôte unique (par exemple, v6e-4), vous pouvez exécuter la commande d'entraînement directement sur la VM TPU. Si vous utilisez une VM TPU multi-hôte, utilisez la commande suivante pour exécuter le script simultanément sur tous les hôtes:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

Résoudre les problèmes liés à PyTorch/XLA

Si vous définissez les variables facultatives pour le débogage dans la section précédente, le profil du modèle sera stocké à l'emplacement spécifié par la variable PROFILE_LOGDIR. Vous pouvez extraire le fichier xplane.pb stocké à cet emplacement et utiliser tensorboard pour afficher les profils dans votre navigateur à l'aide des instructions TensorBoard. Si PyTorch/XLA ne fonctionne pas comme prévu, consultez le guide de dépannage, qui contient des suggestions pour le débogage, le profilage et l'optimisation de vos modèles.

Tutoriel sur le DLRM DCN v2

Ce tutoriel vous explique comment entraîner le modèle DLRM DCN v2 sur un TPU v6e.

Si vous exécutez sur plusieurs hôtes, réinitialisez tpu-runtime avec la version TensorFlow appropriée en exécutant la commande suivante. Si vous exécutez le programme sur un seul hôte, vous n'avez pas besoin d'exécuter les deux commandes suivantes.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Se connecter en SSH à worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

Définir le nom du TPU

export TPU_NAME=${TPU_NAME}

Exécuter DLRM v2

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Exécutez script.sh :

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Les options suivantes sont nécessaires pour exécuter des charges de travail de recommandation (DCN DLRM):

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Résultats du benchmark

La section suivante contient les résultats des analyses comparatives pour le DLRM DCN v2 et MaxDiffusion sur la version 6e.

DLRM DCN v2

Le script d'entraînement DLRM DCN v2 a été exécuté à différentes échelles. Consultez les débits dans le tableau suivant.

v6e-64 v6e-128 v6e-256
Étapes de l'entraînement 7000 7000 7000
Taille du lot global 131072 262144 524288
Débit (exemples/s) 2975334 5111808 10066329

MaxDiffusion

Nous avons exécuté le script d'entraînement pour MaxDiffusion sur un v6e-4, un v6e-16 et un 2xv6e-16. Consultez les débits dans le tableau suivant.

v6e-4 v6e-16 Deux v6e-16
Étapes de l'entraînement 0,069 0.073 0,13
Taille du lot global 8 32 64
Débit (exemples/s) 115,9 438,4 492,3

Collections

La version 6e introduit une nouvelle fonctionnalité appelée "collections" pour les utilisateurs qui exécutent des charges de travail de diffusion. La fonctionnalité Collections ne s'applique qu'à la version 6e.

Les collections vous permettent d'indiquer à Google Cloud les nœuds TPU qui font partie d'une charge de travail de diffusion. Cela permet à l'infrastructure Google Cloud sous-jacente de limiter et de simplifier les interruptions pouvant être appliquées aux charges de travail d'entraînement dans le cours normal des opérations.

Utiliser des collections de l'API Cloud TPU

Une collection à hôte unique dans l'API Cloud TPU est une ressource mise en file d'attente sur laquelle un indicateur spécial (--workload-type = availability-optimized) est défini pour indiquer à l'infrastructure sous-jacente qu'elle est destinée à servir des charges de travail.

La commande suivante provisionne une collection à hôte unique à l'aide de l'API Cloud TPU:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

Surveiller et profiler

Cloud TPU v6e prend en charge la surveillance et le profilage à l'aide des mêmes méthodes que les générations précédentes de Cloud TPU. Pour en savoir plus sur la surveillance, consultez la page Surveiller les VM TPU.