Présentation de Trillium (v6e)
Dans cette documentation, l'API TPU et les journaux, v6e est utilisé pour désigner Trillium. v6e représente la 6e génération de TPU de Google.
Avec 256 puces par pod, la version v6e présente de nombreuses similitudes avec la version v5e. Ce système est optimisé pour être le produit le plus intéressant pour l'entraînement, l'ajustement et le traitement des transformateurs, du texte en image et des réseaux de neurones convolutifs (CNN).
Architecture du système v6e
Pour en savoir plus sur la configuration de Cloud TPU, consultez la documentation sur la version v6e.
Ce document se concentre sur le processus de configuration de l'entraînement de modèle à l'aide des frameworks JAX, PyTorch ou TensorFlow. Avec chaque framework, vous pouvez provisionner des TPU à l'aide de ressources mises en file d'attente ou de Google Kubernetes Engine (GKE). La configuration de GKE peut être effectuée à l'aide de XPK ou de commandes GKE.
Préparer un projet Google Cloud
- Connectez-vous à votre compte Google. Si vous ne l'avez pas déjà fait, créez un compte.
- Dans la console Google Cloud, sélectionnez ou créez un projet Cloud à partir de la page de sélection du projet.
- Activez la facturation pour votre projet Google Cloud. La facturation est obligatoire pour toute utilisation de Google Cloud.
- Installez les composants gcloud alpha.
Exécutez la commande suivante pour installer la dernière version des composants
gcloud
.gcloud components update
Activez l'API TPU à l'aide de la commande
gcloud
suivante dans Cloud Shell. Vous pouvez également l'activer à partir de la console Google Cloud.gcloud services enable tpu.googleapis.com
Activer les autorisations avec le compte de service TPU pour l'API Compute Engine
Les comptes de service permettent au service Cloud TPU d'accéder à d'autres services Google Cloud. Un compte de service géré par l'utilisateur est une bonne pratique Google Cloud. Suivez ces guides pour créer et accorder des rôles. Les rôles suivants sont nécessaires:
- Administrateur TPU
- Administrateur de l'espace de stockage
- Rédacteur de journaux
- Rédacteur de métriques Monitoring
a. Configurez les autorisations XPK avec votre compte utilisateur pour GKE : XPK.
Créez des variables d'environnement pour l'ID et la zone du projet.
gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Créez une identité de service pour la VM TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Sécuriser la capacité
Contactez l'équipe commerciale/compte d'assistance Cloud TPU pour demander un quota TPU et répondre à vos questions sur la capacité.
Provisionner l'environnement Cloud TPU
Les TPU v6e peuvent être provisionnés et gérés avec GKE, avec GKE et XPK (un outil de CLI wrapper sur GKE), ou en tant que ressources mises en file d'attente.
Prérequis
- Vérifiez que votre projet dispose d'un quota
TPUS_PER_TPU_FAMILY
suffisant, qui spécifie le nombre maximal de puces auxquelles vous pouvez accéder dans votre projet Google Cloud. - La version v6e a été testée avec la configuration suivante :
- Python
3.10
ou version ultérieure - Versions logicielles nocturnes :
- JAX
0.4.32.dev20240912
par nuit - LibTPU
0.1.dev20240912+nightly
nightly
- JAX
- Versions logicielles stables :
- JAX + JAX Lib de la version 0.4.35
- Python
- Vérifiez que votre projet dispose d'un quota TPU suffisant pour :
- Quota de VM TPU
- Quota d'adresses IP
- Quota Hyperdisk-balance
- Autorisations de projet utilisateur
- Si vous utilisez GKE avec XPK, consultez la section Autorisations de la console Cloud sur le compte utilisateur ou de service pour connaître les autorisations requises pour exécuter XPK.
Variables d'environnement
Dans Cloud Shell, créez les variables d'environnement suivantes:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-central2-b export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the # default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Description des options de commande
Variable | Description |
NODE_ID | ID attribué par l'utilisateur du TPU créé lors de l'allocation de la requête de ressources mise en file d'attente. |
PROJECT_ID | Nom du projet Google Cloud Utilisez un projet existant ou créez-en un sur . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones TPU. |
ACCELERATOR_TYPE | Consultez la section Types d'accélérateurs. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Il s'agit de l'adresse e-mail de votre compte de service, que vous pouvez trouver dans Google Cloud Console -> IAM -> Comptes de service.
Exemple: tpu-service-account@<votre_ID_de_projet>.iam.gserviceaccount.com.com |
NUM_SLICES | Nombre de tranches à créer (nécessaire pour Multislice uniquement) |
QUEUED_RESOURCE_ID | ID de texte attribué par l'utilisateur à la requête de ressource mise en file d'attente. |
VALID_DURATION | Durée de validité de la requête de ressource mise en file d'attente. |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Optimisations des performances réseau
Pour de meilleures performances,utilisez un réseau avec une MTU (unité de transmission maximale) de 8 896.
Par défaut, un cloud privé virtuel (VPC) ne fournit qu'une MTU de 1 460 octets,ce qui offre des performances réseau non optimales. Vous pouvez définir la MTU d'un réseau VPC sur n'importe quelle valeur comprise entre 1 300 octets et 8 896 octets (inclus). Les tailles de MTU personnalisées les plus courantes sont de 1 500 octets (Ethernet standard) ou de 8 896 octets (maximum possible). Pour en savoir plus, consultez la section Tailles de MTU valides pour le réseau VPC.
Pour en savoir plus sur la modification du paramètre MTU d'un réseau existant ou par défaut, consultez la page Modifier le paramètre de MTU d'un réseau VPC.
L'exemple suivant crée un réseau avec 8 896 MTU.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME} export NETWORK_FW_NAME=${RESOURCE_NAME} export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \
Utilisation de plusieurs cartes d'interface réseau (option pour Multislice)
Les variables d'environnement suivantes sont nécessaires pour un sous-réseau secondaire lorsque vous utilisez un environnement multicouche.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Utilisez les commandes suivantes pour créer un routage IP personnalisé pour le réseau et le sous-réseau.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT}" \
--enable-logging
Une fois qu'une tranche multiréseau a été créée, vous pouvez vérifier que les deux NIC sont utilisées en exécutant --command ifconfig
dans le cadre de la charge de travail XPK. Ensuite, examinez la sortie imprimée de cette charge de travail XPK dans les journaux de la console Cloud et vérifiez que eth0 et eth1 ont mtu=8896.
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Vérifiez que eth0 et eth1 ont mtu=8 896. Pour vérifier que le multi-NIC est en cours d'exécution,exécutez la commande --command "ifconfig" dans le cadre de la charge de travail XPK. Examinez ensuite la sortie imprimée de cette charge de travail xpk dans les journaux de la console Cloud, et vérifiez que eth0 et eth1 ont mtu=8896.
Amélioration des paramètres TCP
Pour les TPU créés à l'aide de l'interface de ressources en file d'attente, vous pouvez exécuter la commande suivante pour améliorer les performances réseau en modifiant les paramètres TCP par défaut pour rto_min
et quickack
.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" --zone "${ZONE}" \ --command='ip route show | while IFS= read -r route; do if ! echo $route | \ grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \ --worker=all
Provisionnement avec des ressources en file d'attente (API Cloud TPU)
La capacité peut être provisionnée à l'aide de la commande create
de la ressource en file d'attente.
Créez une requête de ressource en file d'attente TPU.
L'indicateur
--reserved
n'est nécessaire que pour les ressources réservées, et non pour les ressources à la demande.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved]
Si la requête de ressources mise en file d'attente est créée, l'état du champ "response" est "WAITING_FOR_RESOURCES" ou "FAILED". Si la requête de ressource mise en file d'attente est à l'état "WAITING_FOR_RESOURCES", la ressource mise en file d'attente a été mise en file d'attente et sera provisionnée lorsqu'il y aura suffisamment de capacité TPU. Si la requête de ressource mise en file d'attente est à l'état "FAILED" (ÉCHEC), le motif de l'échec s'affiche dans la sortie. La requête de ressource mise en file d'attente expire si un v6e n'est pas provisionné dans la durée spécifiée, et son état devient "FAILED" (ÉCHEC). Pour en savoir plus, consultez la documentation publique sur les ressources mises en file d'attente.
Lorsque votre demande de ressources en file d'attente est à l'état "ACTIVE", vous pouvez vous connecter à vos VM TPU à l'aide de SSH. Utilisez les commandes
list
oudescribe
pour interroger l'état de votre ressource mise en file d'attente.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Lorsque la ressource mise en file d'attente est dans l'état "ACTIVE", le résultat ressemble à ce qui suit:
state: state: ACTIVE
Gérez vos VM TPU. Pour connaître les options de gestion de vos VM TPU, consultez la section Gérer les VM TPU.
Se connecter à vos VM TPU à l'aide de SSH
Vous pouvez installer des binaires sur chaque VM TPU de votre tranche TPU et exécuter du code. Consultez la section Types de VM pour déterminer le nombre de VM de votre tranche.
Pour installer les binaires ou exécuter du code, vous pouvez utiliser SSH pour vous connecter à une VM à l'aide de la commande
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Pour vous connecter à une VM spécifique via SSH, utilisez l'indicateur
--worker
, qui suit un indice basé sur 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Si les formes de tranche contiennent plus de huit chips, vous aurez plusieurs VM dans une même tranche. Dans ce cas, utilisez les paramètres
--worker=all
et--command
dans votre commandegcloud alpha compute tpus tpu-vm ssh
pour exécuter une commande sur toutes les VM simultanément. Exemple :gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Supprimer une ressource en file d'attente
Supprimez une ressource en file d'attente à la fin de la session ou supprimez les requêtes de ressources en file d'attente qui sont à l'état "FAILED". Pour supprimer une ressource en file d'attente, supprimez la tranche, puis la requête de ressource en file d'attente en deux étapes:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Utiliser GKE avec v6e
Si vous utilisez des commandes GKE avec la version v6e, vous pouvez utiliser des commandes Kubernetes ou XPK pour provisionner des TPU et entraîner ou diffuser des modèles. Consultez Planifier des TPU dans GKE pour découvrir comment utiliser GKE avec des TPU et v6e.
Configuration du framework
Cette section décrit le processus de configuration général pour l'entraînement de modèles de ML à l'aide des frameworks JAX, PyTorch ou TensorFlow. Vous pouvez provisionner des TPU à l'aide de ressources en file d'attente ou de GKE. La configuration de GKE peut être effectuée à l'aide de commandes XPK ou Kubernetes.
Configurer JAX à l'aide de ressources en file d'attente
Installez JAX sur toutes les VM TPU de votre ou vos tranches simultanément à l'aide de gcloud alpha compute tpus tpu-vm ssh
. Pour Multislice, ajoutez --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Vous pouvez exécuter le code Python suivant pour vérifier le nombre de cœurs TPU disponibles dans votre tranche et pour vérifier que tout est correctement installé (les sorties affichées ici ont été produites avec une tranche v6e-16):
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Le résultat ressemble à ce qui suit :
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() indique le nombre total de puces dans la tranche donnée. jax.local_device_count() indique le nombre de puces accessibles par une seule VM dans cette tranche.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
&& pip install -r requirements.txt && pip install . '
Résoudre les problèmes de configuration JAX
En règle générale, nous vous recommandons d'activer la journalisation détaillée dans le fichier manifeste de votre charge de travail GKE. Fournissez ensuite les journaux à l'assistance GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Messages d'erreur
no endpoints available for service 'jobset-webhook-service'
Cette erreur signifie que le jobset n'a pas été installé correctement. Vérifiez si les pods Kubernetes de déploiement jobset-controller-manager sont en cours d'exécution. Pour en savoir plus, consultez la documentation de dépannage JobSet.
TPU initialization failed: Failed to connect
Assurez-vous que votre nœud GKE utilise la version 1.30.4-gke.1348000 ou une version ultérieure (GKE 1.31 n'est pas compatible).
Configuration pour PyTorch
Cette section explique comment commencer à utiliser PJRT sur la version v6e avec PyTorch/XLA. La version recommandée est Python 3.10.
Configurer PyTorch à l'aide de GKE avec XPK
Vous pouvez utiliser le conteneur Docker suivant avec XPK, qui contient déjà les dépendances PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Pour créer une charge de travail XPK, utilisez la commande suivante:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
L'utilisation de --base-docker-image
crée une nouvelle image Docker avec le répertoire de travail actuel intégré au nouveau Docker.
Configurer PyTorch à l'aide de ressources en file d'attente
Suivez ces étapes pour installer PyTorch à l'aide de ressources mises en file d'attente et exécuter un petit script sur la version 6e.
Installez les dépendances à l'aide de SSH pour accéder aux VM.
Pour Multislice, ajoutez --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Améliorer les performances des modèles avec des allocations importantes et fréquentes
Pour les modèles comportant des allocations fréquentes et importantes, nous avons constaté que l'utilisation de tcmalloc
améliore considérablement les performances par rapport à l'implémentation par défaut de malloc
. Par conséquent, la valeur par défaut de malloc
utilisée sur la VM TPU est tcmalloc
. Toutefois, en fonction de votre charge de travail (par exemple, avec DLRM qui dispose d'allocations très importantes pour ses tables d'imbrication), tcmalloc
peut entraîner un ralentissement. Dans ce cas, vous pouvez essayer de réinitialiser la variable suivante à l'aide de malloc
par défaut:
unset LD_PRELOAD
Utilisez un script Python pour effectuer un calcul sur la VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Un résultat semblable à celui-ci doit s'afficher :
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configuration pour TensorFlow
Pour la version Preview publique de la version 6e, seule la version d'exécution tf-nightly est compatible.
Vous pouvez réinitialiser tpu-runtime
avec la version TensorFlow compatible avec la version v6e en exécutant les commandes suivantes:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Utilisez SSH pour accéder à worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Installez TensorFlow sur worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exportez la variable d'environnement TPU_NAME
:
export TPU_NAME=v6e-16
Vous pouvez exécuter le script Python suivant pour vérifier le nombre de cœurs TPU disponibles dans votre tranche et pour vérifier que tout est correctement installé (les sorties affichées ont été produites avec une tranche v6e-16):
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
Le résultat ressemble à ce qui suit :
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e avec SkyPilot
Vous pouvez utiliser TPU v6e avec SkyPilot. Suivez les étapes ci-dessous pour ajouter des informations sur la localisation/la tarification liées à la v6e à SkyPilot.
Ajoutez les éléments suivants à la fin de
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Spécifiez les ressources suivantes dans un fichier YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Lancez un cluster avec TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Connectez-vous au TPU v6e à l'aide de SSH:
ssh tpu_v6
Tutoriels sur l'inférence
Les sections suivantes fournissent des tutoriels sur la diffusion de modèles MaxText et PyTorch à l'aide de JetStream, ainsi que sur la diffusion de modèles MaxDiffusion sur TPU v6e.
MaxText sur JetStream
Ce tutoriel explique comment utiliser JetStream pour diffuser des modèles MaxText (JAX) sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU). Dans ce tutoriel, vous allez exécuter le benchmark d'inférence pour le modèle Llama2-7B.
Avant de commencer
Créez un TPU v6e avec quatre puces:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Connectez-vous au TPU à l'aide de SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Exécuter le tutoriel
Pour configurer JetStream et MaxText, convertir les points de contrôle du modèle et exécuter le benchmark d'inférence, suivez les instructions du dépôt GitHub.
Effectuer un nettoyage
Supprimez le TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
vLLM sur PyTorch TPU
Vous trouverez ci-dessous un tutoriel simple qui vous explique comment commencer à utiliser vLLM sur une VM TPU. Pour notre exemple de bonnes pratiques de déploiement de vLLM sur Trillium en production, nous publierons un guide de l'utilisateur GKE dans les prochains jours (restez à l'écoute).
Avant de commencer
Créez un TPU v6e avec quatre puces:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Description des options de commande
Variable Description NODE_ID ID attribué par l'utilisateur du TPU créé lors de l'allocation de la requête de ressource mise en file d'attente. PROJECT_ID Nom du projet Google Cloud Utilisez un projet existant ou créez-en un sur . ZONE Pour connaître les zones compatibles, consultez le document Régions et zones TPU. ACCELERATOR_TYPE Consultez la section Types d'accélérateurs. RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Il s'agit de l'adresse e-mail de votre compte de service, que vous pouvez trouver dans Google Cloud Console -> IAM -> Comptes de service. Exemple: tpu-service-account@<votre_ID_de_projet>.iam.gserviceaccount.com.com
Connectez-vous au TPU à l'aide de SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Create a Conda environment
(Recommended) Create a new conda environment for vLLM:
conda create -n vllm python=3.10 -y conda activate vllm
Configurer vLLM sur un TPU
Clonez le dépôt vLLM et accédez au répertoire vLLM:
git clone https://github.com/vllm-project/vllm.git && cd vllm
Nettoyez les packages torch et torch-xla existants:
pip uninstall torch torch-xla -y
Installez PyTorch et PyTorch XLA:
pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Installez JAX et Pallas:
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Installez les autres dépendances de compilation:
pip install -r requirements-tpu.txt VLLM_TARGET_DEVICE="tpu" python setup.py develop sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Accéder au modèle
Vous devez signer le contrat de consentement pour utiliser la famille de modèles Llama3 dans le dépôt HuggingFace.
Générez un nouveau jeton Hugging Face si vous n'en possédez pas déjà un :
- Cliquez sur Your Profile > Settings > Access Tokens (Votre profil > Paramètres > Jetons d'accès).
- Sélectionnez New Token (Nouveau jeton).
- Spécifiez le nom de votre choix et un rôle d'au moins
Read
. - Sélectionnez Générer un jeton.
Copiez le jeton généré dans votre presse-papiers, définissez-le en tant que variable d'environnement et authentifiez-vous avec huggingface-cli:
export TOKEN='' git config --global credential.helper store huggingface-cli login --token $TOKEN
Télécharger les données de benchmark
Créez un répertoire /data et téléchargez l'ensemble de données ShareGPT depuis Hugging Face.
mkdir ~/data && cd ~/data wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
Lancer le serveur vLLM
La commande suivante télécharge les poids du modèle à partir du hub de modèles Hugging Face dans le répertoire /tmp de la VM TPU, précompile une plage de formes d'entrée et écrit la compilation du modèle dans ~/.cache/vllm/xla_cache
.
Pour en savoir plus, consultez la documentation sur vLLM.
cd ~/vllm
vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &
Exécuter des benchmarks vLLM
Exécutez le script de benchmark vLLM:
python benchmarks/benchmark_serving.py \
--backend vllm \
--model "meta-llama/Meta-Llama-3.1-8B" \
--dataset-name sharegpt \
--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 1000
Effectuer un nettoyage
Supprimez le TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
PyTorch sur JetStream
Ce tutoriel explique comment utiliser JetStream pour diffuser des modèles PyTorch sur TPU v6e. JetStream est un moteur optimisé pour le débit et la mémoire pour l'inférence de grands modèles de langage (LLM) sur les appareils XLA (TPU). Dans ce tutoriel, vous allez exécuter le benchmark d'inférence pour le modèle Llama2-7B.
Avant de commencer
Créez un TPU v6e avec quatre puces:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Connectez-vous au TPU à l'aide de SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Exécuter le tutoriel
Pour configurer JetStream-PyTorch, convertir les points de contrôle du modèle et exécuter le benchmark d'inférence, suivez les instructions du dépôt GitHub.
Effectuer un nettoyage
Supprimez le TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--force \
--async
Inférence MaxDiffusion
Ce tutoriel explique comment diffuser des modèles MaxDiffusion sur TPU v6e. Dans ce tutoriel, vous allez générer des images à l'aide du modèle Stable Diffusion XL.
Avant de commencer
Créez un TPU v6e avec quatre puces:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Connectez-vous au TPU à l'aide de SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Créer un environnement Conda
Créez un répertoire pour Miniconda:
mkdir -p ~/miniconda3
Téléchargez le script d'installation de Miniconda:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
Installez Miniconda:
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
Supprimez le script d'installation Miniconda:
rm -rf ~/miniconda3/miniconda.sh
Ajoutez Miniconda à votre variable
PATH
:export PATH="$HOME/miniconda3/bin:$PATH"
Actualisez
~/.bashrc
pour appliquer les modifications à la variablePATH
:source ~/.bashrc
Créez un environnement Conda:
conda create -n tpu python=3.10
Activez l'environnement Conda:
source activate tpu
Configurer MaxDiffusion
Clonez le dépôt MaxDiffusion et accédez au répertoire MaxDiffusion:
https://github.com/google/maxdiffusion.git && cd maxdiffusion
Passez à la branche
mlperf-4.1
:git checkout mlperf4.1
Installez MaxDiffusion:
pip install -e .
Installez les dépendances :
pip install -r requirements.txt
Installez JAX:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Générer des images
Définissez des variables d'environnement pour configurer l'environnement d'exécution TPU:
LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
Générez des images à l'aide de la requête et des configurations définies dans
src/maxdiffusion/configs/base_xl.yml
:python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
Effectuer un nettoyage
Supprimez le TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
Tutoriels de formation
Les sections suivantes fournissent des tutoriels pour entraîner MaxText.
Modèles MaxDiffusion et PyTorch sur TPU v6e.
MaxText et MaxDiffusion
Les sections suivantes couvrent le cycle de vie de l'entraînement des modèles MaxText et MaxDiffusion.
En général, les étapes de haut niveau sont les suivantes:
- Créez l'image de base de la charge de travail.
- Exécutez votre charge de travail à l'aide de XPK.
- Créez la commande d'entraînement pour la charge de travail.
- Déployez la charge de travail.
- Suivez la charge de travail et affichez les métriques.
- Supprimez la charge de travail XPK si elle n'est pas nécessaire.
- Supprimez le cluster XPK lorsqu'il n'est plus nécessaire.
Créer une image de base
Installez MaxText ou MaxDiffusion, puis créez l'image Docker:
Clonez le dépôt que vous souhaitez utiliser et accédez au répertoire du dépôt:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configurez Docker pour qu'il utilise la Google Cloud CLI:
gcloud auth configure-docker
Créez l'image Docker à l'aide de la commande suivante ou à l'aide de la pile stable JAX. Pour en savoir plus sur la pile stable JAX, consultez Créer une image Docker avec la pile stable JAX.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
Si vous lancez la charge de travail à partir d'une machine sur laquelle l'image n'est pas compilée localement, importez-la:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Créer une image Docker avec la pile stable JAX
Vous pouvez créer les images Docker MaxText et MaxDiffusion à l'aide de l'image de base de la pile stable JAX.
La pile stable JAX fournit un environnement cohérent pour MaxText et MaxDiffusion en regroupant JAX avec des paquets de base tels que orbax
, flax
et optax
, ainsi qu'une libtpu.so bien qualifiée qui gère les utilitaires de programme TPU et d'autres outils essentiels. Ces bibliothèques sont testées pour garantir la compatibilité, ce qui fournit une base stable pour la compilation et l'exécution de MaxText et de MaxDiffusion, et élimine les conflits potentiels dus à des versions de packages incompatibles.
La pile stable JAX inclut une libtpu.so entièrement publiée et qualifiée, la bibliothèque principale qui gère la compilation, l'exécution et la configuration du réseau ICI des programmes TPU. La version libtpu remplace le build quotidien précédemment utilisé par JAX et garantit la fonctionnalité cohérente des calculs XLA sur le TPU avec des tests de qualification au niveau PJRT dans les IR HLO/StableHLO.
Pour créer l'image Docker MaxText et MaxDiffusion avec la pile stable JAX, lorsque vous exécutez le script docker_build_dependency_image.sh
, définissez la variable MODE
sur stable_stack
et la variable BASEIMAGE
sur l'image de base que vous souhaitez utiliser.
L'exemple suivant spécifie us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
comme image de base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
Pour obtenir la liste des images de base de la pile stable JAX disponibles, consultez la section Images de la pile stable JAX dans Artifact Registry.
Exécuter votre charge de travail à l'aide de XPK
Définissez les variables d'environnement suivantes si vous n'utilisez pas les valeurs par défaut définies par MaxText ou MaxDiffusion:
BASE_OUTPUT_DIR=gs://YOUR_BUCKET PER_DEVICE_BATCH_SIZE=2 NUM_STEPS=30 MAX_TARGET_LENGTH=8192
Créez votre script de modèle à copier en tant que commande d'entraînement à l'étape suivante. N'exécutez pas encore le script du modèle.
MaxText
MaxText est un LLM Open Source hautes performances et hautement évolutif, écrit en Python et JAX purs, et ciblant les TPU et GPU Google Cloud pour l'entraînement et l'inférence.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=$BASE_OUTPUT_DIR \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS}" # attention='dot_product'"
Gemma2
Gemma est une famille de grands modèles de langage (LLM) à poids ouverts développés par Google DeepMind, basés sur la recherche et la technologie Gemini.
# Requires v6e-256 python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral est un modèle d'IA de pointe développé par Mistral AI, qui utilise une architecture MoE (Mixture of Experts) sporadique.
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama est une famille de grands modèles de langage (LLM) à pondération ouverte développés par Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion est un ensemble d'implémentations de référence de divers modèles de diffusion latente écrits en Python et JAX purs, qui s'exécutent sur des appareils XLA, y compris des Cloud TPU et des GPU. Stable Diffusion est un modèle de texte vers image latent qui génère des images photoréalistes à partir de n'importe quelle entrée textuelle.
Vous devez installer une branche spécifique pour exécuter MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Script d'entraînement:
cd maxdiffusion && OUT_DIR=${your_own_bucket} python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \ run_name=v6e-sd2 \ split_head_dim=True \ attention=flash \ train_new_unet=false \ norm_num_groups=16 \ output_dir=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ [dcn_data_parallelism=2] \ enable_profiler=True \ skip_first_n_steps_for_profiler=95 \ max_train_steps=${NUM_STEPS} ] write_metrics=True'
Exécutez le modèle à l'aide du script que vous avez créé à l'étape précédente. Vous devez spécifier l'indicateur
--base-docker-image
pour utiliser l'image de base MaxText ou l'indicateur--docker-image
et l'image que vous souhaitez utiliser.Facultatif: vous pouvez activer la journalisation de débogage en incluant l'option
--enable-debug-logs
. Pour en savoir plus, consultez Déboguer JAX sur MaxText.Facultatif: vous pouvez créer un Vertex AI Experiment pour importer des données dans Vertex AI TensorBoard en incluant l'indicateur
--use-vertex-tensorboard
. Pour en savoir plus, consultez Surveiller JAX sur MaxText à l'aide de Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \ --tpu-type=ACCELERATOR_TYPE \ --num-slices=NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command YOUR_MODEL_SCRIPT
Remplacez les variables suivantes :
- CLUSTER_NAME: nom de votre cluster XPK.
- ACCELERATOR_TYPE: version et taille de votre TPU. Exemple :
v6e-256
- NUM_SLICES: nombre de tranches TPU.
- YOUR_MODEL_SCRIPT: script du modèle à exécuter en tant que commande d'entraînement.
Le résultat inclut un lien permettant de suivre votre charge de travail, semblable à celui-ci:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Ouvrez le lien, puis cliquez sur l'onglet Journaux pour suivre votre charge de travail en temps réel.
Déboguer JAX sur MaxText
Utilisez des commandes XPK supplémentaires pour diagnostiquer pourquoi le cluster ou la charge de travail ne s'exécute pas:
- Liste des charges de travail XPK
- Inspecteur XPK
- Activez la journalisation détaillée dans vos journaux de charge de travail à l'aide de l'option
--enable-debug-logs
lorsque vous créez la charge de travail XPK.
Surveiller JAX sur MaxText à l'aide de Vertex AI
Afficher les données scalaires et de profil via TensorBoard géré par Vertex AI.
- Augmentez le nombre de requêtes de gestion des ressources (CRUD) pour la zone que vous utilisez de 600 à 5 000. Cela peut ne pas poser de problème pour les petites charges de travail utilisant moins de 16 VM.
Installez des dépendances telles que
cloud-accelerator-diagnostics
pour Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Créez votre cluster XPK à l'aide de l'option
--create-vertex-tensorboard
, comme décrit dans Créer Vertex AI TensorBoard. Vous pouvez également exécuter cette commande sur des clusters existants.Créez votre test Vertex AI lorsque vous exécutez votre charge de travail XPK à l'aide de l'indicateur
--use-vertex-tensorboard
et de l'indicateur--experiment-name
facultatif. Pour obtenir la liste complète des étapes, consultez Créer un test Vertex AI pour importer des données dans Vertex AI TensorBoard.
Les journaux incluent un lien vers un Vertex AI TensorBoard, semblable à celui-ci:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Vous pouvez également trouver le lien Vertex AI TensorBoard dans la console Google Cloud. Accédez à Tests Vertex AI dans la console Google Cloud. Sélectionnez la région appropriée dans le menu déroulant.
Le répertoire TensorBoard est également écrit dans le bucket Cloud Storage que vous avez spécifié avec ${BASE_OUTPUT_DIR}
.
Supprimer des charges de travail XPK
Utilisez la commande xpk workload delete
pour supprimer une ou plusieurs charges de travail en fonction du préfixe ou de l'état de la tâche. Cette commande peut être utile si vous avez envoyé des charges de travail XPK qui n'ont plus besoin d'être exécutées ou si des tâches sont bloquées dans la file d'attente.
Supprimer le cluster XPK
Exécutez la commande xpk cluster delete
pour supprimer un cluster:
python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID
Llama et PyTorch
Ce tutoriel explique comment entraîner des modèles Llama à l'aide de PyTorch/XLA sur TPU v6e à l'aide de l'ensemble de données WikiText. En outre, les utilisateurs peuvent accéder aux descriptions de modèles TPU PyTorch en tant qu'images Docker ici.
Installation
Installez la fourchette pytorch-tpu/transformers
de Hugging Face Transformers et les dépendances dans un environnement virtuel:
git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate
Configurer les configurations de modèle
La commande d'entraînement de la section suivante, Créer votre script de modèle, utilise deux fichiers de configuration JSON pour définir les paramètres du modèle et la configuration FSDP (Fully Sharded Data Parallel). Le fractionnement FSDP permet aux poids du modèle de s'adapter à une taille de lot plus importante lors de l'entraînement. Lorsque vous entraînez des modèles plus petits, il peut suffire d'utiliser le parallélisme des données et de répliquer les poids sur chaque appareil. Pour en savoir plus sur le fractionnement des tenseurs sur plusieurs appareils dans PyTorch/XLA, consultez le guide de l'utilisateur de SPMD PyTorch/XLA.
Créez le fichier de configuration des paramètres du modèle. Voici la configuration des paramètres du modèle pour Llama3-8B. Pour les autres modèles, recherchez la configuration sur Hugging Face. Par exemple, consultez la configuration Llama2-7B.
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 }
Créez le fichier de configuration FSDP:
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true }
Pour en savoir plus sur le FSDP, consultez FSDPv2.
Importez les fichiers de configuration dans vos VM TPU à l'aide de la commande suivante:
gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Vous pouvez également créer les fichiers de configuration dans votre répertoire de travail actuel et utiliser l'indicateur
--base-docker-image
dans XPK.
Créer le script de votre modèle
Créez votre script de modèle, en spécifiant le fichier de configuration des paramètres du modèle à l'aide de l'indicateur --config_name
et le fichier de configuration FSDP à l'aide de l'indicateur --fsdp_config
.
Vous exécuterez ce script sur votre TPU dans la section suivante, Exécuter le modèle. N'exécutez pas encore le script du modèle.
PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 PROFILE_EPOCH=0 PROFILE_STEP=3 PROFILE_DURATION_MS=100000 PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 8 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/config-8B.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp_config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20
Exécuter le modèle
Exécutez le modèle à l'aide du script que vous avez créé à l'étape précédente, Créer le script de votre modèle.
Si vous utilisez une VM TPU à hôte unique (par exemple, v6e-4
), vous pouvez exécuter la commande d'entraînement directement sur la VM TPU. Si vous utilisez une VM TPU multi-hôte, utilisez la commande suivante pour exécuter le script simultanément sur tous les hôtes:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=YOUR_COMMAND
Résoudre les problèmes liés à PyTorch/XLA
Si vous définissez les variables facultatives pour le débogage dans la section précédente, le profil du modèle sera stocké à l'emplacement spécifié par la variable PROFILE_LOGDIR
. Vous pouvez extraire le fichier xplane.pb
stocké à cet emplacement et utiliser tensorboard
pour afficher les profils dans votre navigateur à l'aide des instructions TensorBoard. Si PyTorch/XLA ne fonctionne pas comme prévu, consultez le guide de dépannage, qui contient des suggestions pour le débogage, le profilage et l'optimisation de vos modèles.
Tutoriel sur le DLRM DCN v2
Ce tutoriel vous explique comment entraîner le modèle DLRM DCN v2 sur un TPU v6e.
Si vous exécutez sur plusieurs hôtes, réinitialisez tpu-runtime
avec la version TensorFlow appropriée en exécutant la commande suivante. Si vous exécutez le programme sur un seul hôte, vous n'avez pas besoin d'exécuter les deux commandes suivantes.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Se connecter en SSH à worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Définir le nom du TPU
export TPU_NAME=${TPU_NAME}
Exécuter DLRM v2
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Exécutez script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Les options suivantes sont nécessaires pour exécuter des charges de travail de recommandation (DCN DLRM):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Résultats du benchmark
La section suivante contient les résultats des analyses comparatives pour le DLRM DCN v2 et MaxDiffusion sur la version 6e.
DLRM DCN v2
Le script d'entraînement DLRM DCN v2 a été exécuté à différentes échelles. Consultez les débits dans le tableau suivant.
v6e-64 | v6e-128 | v6e-256 | |
Étapes de l'entraînement | 7000 | 7000 | 7000 |
Taille du lot global | 131072 | 262144 | 524288 |
Débit (exemples/s) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Nous avons exécuté le script d'entraînement pour MaxDiffusion sur un v6e-4, un v6e-16 et un 2xv6e-16. Consultez les débits dans le tableau suivant.
v6e-4 | v6e-16 | Deux v6e-16 | |
Étapes de l'entraînement | 0,069 | 0.073 | 0,13 |
Taille du lot global | 8 | 32 | 64 |
Débit (exemples/s) | 115,9 | 438,4 | 492,3 |
Collections
La version 6e introduit une nouvelle fonctionnalité appelée "collections" pour les utilisateurs qui exécutent des charges de travail de diffusion. La fonctionnalité Collections ne s'applique qu'à la version 6e.
Les collections vous permettent d'indiquer à Google Cloud les nœuds TPU qui font partie d'une charge de travail de diffusion. Cela permet à l'infrastructure Google Cloud sous-jacente de limiter et de simplifier les interruptions pouvant être appliquées aux charges de travail d'entraînement dans le cours normal des opérations.
Utiliser des collections de l'API Cloud TPU
Une collection à hôte unique dans l'API Cloud TPU est une ressource mise en file d'attente sur laquelle un indicateur spécial (--workload-type = availability-optimized
) est défini pour indiquer à l'infrastructure sous-jacente qu'elle est destinée à servir des charges de travail.
La commande suivante provisionne une collection à hôte unique à l'aide de l'API Cloud TPU:
gcloud alpha compute tpus queued-resources create COLLECTION_NAME \ --project=project name \ --zone=zone name \ --accelerator-type=accelerator type \ --node-count=number of nodes \ --workload-type=availability-optimized
Surveiller et profiler
Cloud TPU v6e prend en charge la surveillance et le profilage à l'aide des mêmes méthodes que les générations précédentes de Cloud TPU. Pour en savoir plus sur la surveillance, consultez la page Surveiller les VM TPU.