Présentation de Trillium (v6e)
Dans cette documentation, l'API TPU et les journaux, v6e est utilisé pour désigner Trillium. v6e représente la 6e génération de TPU de Google.
Avec 256 puces par pod, l'architecture v6e présente de nombreuses similitudes avec la version v5e. Ce système est optimisé pour l'entraînement, l'ajustement et la diffusion des transformateurs, de la conversion texte-image et des réseaux de neurones convolutifs (CNN).
Pour en savoir plus sur l'architecture et les configurations du système v6e, consultez le document v6e.
Ce document d'introduction se concentre sur les processus d'entraînement et de diffusion des modèles à l'aide des frameworks JAX, PyTorch ou TensorFlow. Avec chaque framework, vous pouvez provisionner des TPU à l'aide de ressources en file d'attente ou de Google Kubernetes Engine (GKE). La configuration de GKE peut être effectuée à l'aide de XPK ou de commandes GKE.
Procédure générale pour entraîner ou diffuser un modèle à l'aide de la version v6e
- Préparer un Google Cloud projet
- Capacité sécurisée
- Configurer votre environnement TPU
- Provisionner l'environnement Cloud TPU
- Exécuter une charge de travail d' entraînement ou d'inférence de modèle
- Nettoyer
Préparer un Google Cloud projet
- Connectez-vous à votre compte Google. Si vous ne l'avez pas déjà fait, créez un compte.
- Dans la console Google Cloud, sélectionnez ou créez un projet Cloud à partir de la page de sélection du projet.
- Activez la facturation pour votre projet Google Cloud. La facturation est obligatoire pour toute utilisation de Google Cloud.
- Installez les composants gcloud alpha.
Exécutez la commande suivante pour installer la dernière version des composants
gcloud
.gcloud components update
Activez l'API TPU à l'aide de la commande
gcloud
suivante dans Cloud Shell. Vous pouvez également l'activer à partir de la console Google Cloud.gcloud services enable tpu.googleapis.com
Activer les autorisations avec le compte de service TPU pour l'API Compute Engine
Les comptes de service permettent au service Cloud TPU d'accéder à d'autres services Google Cloud. L'utilisation d'un compte de service géré par l'utilisateur est une bonne pratique Google Cloud. Suivez ces guides pour créer et accorder des rôles. Les rôles suivants sont nécessaires:
- Administrateur TPU
- Administrateur de l'espace de stockage
- Rédacteur de journaux
- Rédacteur de métriques Monitoring
a. Configurez les autorisations XPK avec votre compte utilisateur pour GKE : XPK.
Authentifiez-vous avec votre compte Google, puis définissez l'ID de projet et la zone par défaut.
auth login
autorise gcloud à accéder à Google Cloud avec des identifiants utilisateur Google.
PROJECT_ID
est le Google Cloud nom du projet.
ZONE
est la zone dans laquelle vous souhaitez créer le TPU.gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Créez une identité de service pour la VM TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Sécuriser la capacité
Contactez l'équipe commerciale/compte d'assistance Cloud TPU pour demander un quota TPU et répondre à vos questions sur la capacité.
Provisionner l'environnement Cloud TPU
Les TPU v6e peuvent être provisionnés et gérés avec GKE, avec GKE et XPK (un outil de CLI wrapper sur GKE), ou en tant que ressources mises en file d'attente.
Prérequis
- Vérifiez que votre projet dispose d'un quota
TPUS_PER_TPU_FAMILY
suffisant, qui spécifie le nombre maximal de puces auxquelles vous pouvez accéder dans votre projetGoogle Cloud . - La version v6e a été testée avec la configuration suivante :
- Python
3.10
ou version ultérieure - Versions logicielles nocturnes :
0.4.32.dev20240912
JAX- LibTPU
0.1.dev20240912+nightly
nightly
- Versions logicielles stables :
- JAX + JAX Lib de la version 0.4.37
- Python
Vérifiez que votre projet dispose d'un quota TPU suffisant pour:
- Quota de VM TPU
- Quota d'adresses IP
Quota Hyperdisk équilibré
Autorisations des projets utilisateur
- Si vous utilisez GKE avec XPK, consultez la section Autorisations de la console Cloud sur le compte utilisateur ou de service pour connaître les autorisations requises pour exécuter XPK.
Variables d'environnement
Dans Cloud Shell, créez les variables d'environnement suivantes:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Description des options de commande
Variable | Description |
NODE_ID | ID attribué par l'utilisateur du TPU créé lors de l'allocation de la requête de ressource mise en file d'attente. |
PROJECT_ID | Google Cloud Nom du projet. Utilisez un projet existant ou créez-en un sur . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones TPU. |
ACCELERATOR_TYPE | Consultez la section Types d'accélérateurs. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Il s'agit de l'adresse e-mail de votre compte de service, que vous pouvez trouver dans Google Cloud Console -> IAM -> Comptes de service.
Exemple: tpu-service-account@<votre_ID_de_projet>.iam.gserviceaccount.com.com |
NUM_SLICES | Nombre de tranches à créer (nécessaire pour Multislice uniquement). |
QUEUED_RESOURCE_ID | ID de texte attribué par l'utilisateur à la requête de ressource mise en file d'attente. |
VALID_DURATION | Durée de validité de la requête de ressource mise en file d'attente. |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Optimisations des performances réseau
Pour de meilleures performances,utilisez un réseau avec une MTU (unité de transmission maximale) de 8 896.
Par défaut, un cloud privé virtuel (VPC) ne fournit qu'une MTU de 1 460 octets,ce qui offre des performances réseau non optimales. Vous pouvez définir la MTU d'un réseau VPC sur n'importe quelle valeur comprise entre 1 300 octets et 8 896 octets (inclus). Les tailles de MTU personnalisées les plus courantes sont de 1 500 octets (Ethernet standard) ou de 8 896 octets (le maximum possible). Pour en savoir plus, consultez la section Tailles de MTU valides pour le réseau VPC.
Pour en savoir plus sur la modification du paramètre MTU d'un réseau existant ou par défaut, consultez la page Modifier le paramètre de MTU d'un réseau VPC.
L'exemple suivant crée un réseau avec 8 896 MTU.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
Utiliser plusieurs cartes d'interface réseau (option pour Multislice)
Les variables d'environnement suivantes sont nécessaires pour un sous-réseau secondaire lorsque vous utilisez un environnement multicouche.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Utilisez les commandes suivantes pour créer un routage IP personnalisé pour le réseau et le sous-réseau.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT_ID}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT_ID}" \
--enable-logging
Une fois qu'une tranche multiréseau a été créée, vous pouvez vérifier que les deux NIC sont utilisées en configurant un cluster XPK et en exécutant --command ifconfig
dans le cadre de la charge de travail XPK.
Utilisez la commande xpk workload
suivante pour afficher la sortie de la commande ifconfig
dans les journaux de la console Cloud et vérifier que les valeurs mtu=8896 sont définies pour eth0 et eth1.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Vérifiez que eth0 et eth1 ont mtu=8 896. Pour vérifier que le multi-NIC est en cours d'exécution,exécutez la commande --command "ifconfig" dans le cadre de la charge de travail XPK. Examinez ensuite la sortie imprimée de cette charge de travail xpk dans les journaux de la console Cloud, et vérifiez que eth0 et eth1 ont mtu=8896.
Amélioration des paramètres TCP
Pour les TPU créés à l'aide de l'interface de ressources en file d'attente, vous pouvez exécuter la commande suivante pour améliorer les performances du réseau en augmentant les limites de tampon de réception TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" \ --zone "$ZONE" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
Provisionnement avec des ressources en file d'attente
La capacité allouée peut être provisionnée à l'aide de la commande create
de la ressource en file d'attente.
Créez une requête de ressource en file d'attente TPU.
L'indicateur
--reserved
n'est nécessaire que pour les ressources réservées, et non pour les ressources à la demande.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved] # The following flags are only needed if you are using Multislice. --node-count node-count # Number of slices in a Multislice \ --node-prefix node-prefix # An optional user-defined node prefix; the default is QUEUED_RESOURCE_ID.
Si la requête de ressources mise en file d'attente est créée, l'état du champ "response" est "WAITING_FOR_RESOURCES" ou "FAILED". Si la requête de ressources en file d'attente est à l'état "WAITING_FOR_RESOURCES" (EN ATTENTE DE RESSOURCES), la ressource a été ajoutée à la file d'attente et sera provisionnée lorsqu'il y aura suffisamment de capacité TPU allouée. Si la requête de ressource mise en file d'attente est à l'état "FAILED" (ÉCHEC), le motif de l'échec s'affiche dans la sortie. La requête de ressources mise en file d'attente expire si un v6e n'est pas provisionné dans la durée spécifiée, et son état devient "ÉCHEC". Pour en savoir plus, consultez la documentation publique sur les ressources mises en file d'attente.
Lorsque votre demande de ressources en file d'attente est à l'état "ACTIVE", vous pouvez vous connecter à vos VM TPU à l'aide de SSH. Utilisez les commandes
list
oudescribe
pour interroger l'état de votre ressource mise en file d'attente.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Lorsque la ressource mise en file d'attente est dans l'état "ACTIVE", le résultat ressemble à ce qui suit:
state: state: ACTIVE
Gérez vos VM TPU. Pour connaître les options de gestion de vos VM TPU, consultez la section Gérer les VM TPU.
Se connecter à vos VM TPU à l'aide de SSH
Vous pouvez installer des binaires sur chaque VM TPU de votre tranche TPU et exécuter du code. Consultez la section Types de VM pour déterminer le nombre de VM de votre tranche.
Pour installer les binaires ou exécuter du code, vous pouvez utiliser SSH pour vous connecter à une VM à l'aide de la commande
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Pour vous connecter à une VM spécifique via SSH, utilisez l'indicateur
--worker
, qui suit un indice basé sur 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Si les formes de tranche contiennent plus de huit chips, vous aurez plusieurs VM dans une même tranche. Dans ce cas, utilisez les paramètres
--worker=all
et--command
dans votre commandegcloud alpha compute tpus tpu-vm ssh
pour exécuter une commande sur toutes les VM simultanément. Exemple :gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Supprimer une ressource en file d'attente
Supprimez une ressource en file d'attente à la fin de la session ou supprimez les requêtes de ressources en file d'attente qui sont à l'état "FAILED". Pour supprimer une ressource en file d'attente, supprimez la tranche, puis la requête de ressource en file d'attente en deux étapes:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Provisionnement de TPU v6e avec GKE ou XPK
Si vous utilisez des commandes GKE avec la version v6e, vous pouvez utiliser des commandes Kubernetes ou XPK pour provisionner des TPU et entraîner ou diffuser des modèles. Consultez Planifier des TPU dans GKE pour découvrir comment planifier vos configurations de TPU dans les clusters GKE. Les sections suivantes fournissent des commandes permettant de créer un cluster XPK avec prise en charge d'une seule carte réseau et de plusieurs cartes réseau.
Commandes permettant de créer un cluster XPK avec une seule carte réseau
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network ${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster $CLUSTER_NAME \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=$NUM_SLICES \ --tpu-type=$TPU_TYPE \ --zone=$ZONE \ --project=$PROJECT \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Description des options de commande
Variable | Description |
CLUSTER_NAME | Nom attribué par l'utilisateur au cluster XPK. |
PROJECT_ID | Google Cloud Nom du projet. Utilisez un projet existant ou créez-en un sur . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones TPU. |
TPU_TYPE | Consultez la section Types d'accélérateurs. |
NUM_SLICES | Nombre de tranches que vous souhaitez créer |
CLUSTER_ARGUMENTS | Réseau et sous-réseau à utiliser.
Par exemple: "--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}" |
NUM_SLICES | Nombre de secteurs à créer. |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Commandes permettant de créer un cluster XPK compatible avec plusieurs NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create "${NETWORK_NAME_1}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_1}" \ --network="${NETWORK_NAME_1}" \ --range=10.11.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_1}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_1}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create "${NETWORK_NAME_2}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \ --network="${NETWORK_NAME_2}" \ --range=10.10.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_2}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_2}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard
Description des options de commande
Variable | Description |
CLUSTER_NAME | Nom attribué par l'utilisateur au cluster XPK. |
PROJECT_ID | Google Cloud Nom du projet. Utilisez un projet existant ou créez-en un sur . |
ZONE | Pour connaître les zones compatibles, consultez le document Régions et zones TPU. |
TPU_TYPE | Consultez la section Types d'accélérateurs. |
NUM_SLICES | Nombre de tranches que vous souhaitez créer |
CLUSTER_ARGUMENTS | Réseau et sous-réseau à utiliser.
Par exemple: "--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}" |
NODE_POOL_ARGUMENTS | Réseau de nœuds supplémentaire à utiliser.
Par exemple: "--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}" |
NUM_SLICES | Nombre de tranches à créer (nécessaire pour Multislice uniquement). |
NETWORK_NAME | Nom d'un réseau secondaire à utiliser. |
NETWORK_FW_NAME | Nom d'un pare-feu réseau secondaire à utiliser. |
Configuration du framework
Cette section décrit le processus de configuration général pour l'entraînement de modèles de ML à l'aide des frameworks JAX, PyTorch ou TensorFlow. Vous pouvez provisionner des TPU à l'aide de ressources en file d'attente ou de GKE. La configuration de GKE peut être effectuée à l'aide de commandes XPK ou Kubernetes.
Configuration pour JAX
Cette section fournit des exemples d'exécution de charges de travail JAX sur GKE, avec ou sans XPK, ainsi que d'utilisation de ressources en file d'attente.
Configurer JAX avec GKE
L'exemple suivant configure un hôte unique 2X2 à l'aide d'un fichier YAML Kubernetes.
Tranche unique sur un seul hôte
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Une fois l'opération terminée, le message suivant doit s'afficher dans le journal GKE:
Total TPU chips: 4
Tranche unique sur plusieurs hôtes
L'exemple suivant configure un pool de nœuds multi-hôtes 4x4 à l'aide d'un fichier YAML Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Une fois l'opération terminée, le message suivant doit s'afficher dans le journal GKE:
Total TPU chips: 16
Multislice sur multi-hôte
L'exemple suivant configure deux pools de nœuds multi-hôtes 4x4 à l'aide d'un fichier YAML Kubernetes.
Vous devez d'abord installer JobSet v0.2.3 ou une version ultérieure.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Une fois l'opération terminée, le message suivant doit s'afficher dans le journal GKE:
Total TPU chips: 32
Pour en savoir plus, consultez la section Exécuter une charge de travail multicouche dans la documentation de GKE.
Pour améliorer les performances, activez hostNetwork.
Multi-NIC
Pour profiter de la multi-NIC dans GKE, le fichier manifeste du pod Kubernetes doit comporter des annotations supplémentaires. Voici un exemple de fichier manifeste de charge de travail multi-NIC non TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Si vous exec
dans le pod Kubernetes, vous devriez voir la NIC supplémentaire à l'aide du code suivant.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configurer JAX à l'aide de GKE avec XPK
Pour en savoir plus, consultez le fichier README xpk.
Pour configurer et exécuter XPK avec MaxText, consultez Exécuter MaxText.
Configurer JAX à l'aide de ressources en file d'attente
Installez JAX sur toutes les VM TPU de votre ou vos tranches simultanément à l'aide de gcloud alpha compute tpus tpu-vm ssh
. Pour Multislice, ajoutez --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Vous pouvez exécuter le code Python suivant pour vérifier le nombre de cœurs TPU disponibles dans votre tranche et pour vérifier que tout est correctement installé (les sorties affichées ici ont été produites avec une tranche v6e-16):
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
Le résultat ressemble à ce qui suit :
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() indique le nombre total de puces dans la tranche donnée. jax.local_device_count() indique le nombre de puces accessibles par une seule VM dans cette tranche.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
Résoudre les problèmes de configuration JAX
En règle générale, nous vous recommandons d'activer la journalisation détaillée dans le fichier manifeste de votre charge de travail GKE. Fournissez ensuite les journaux à l'assistance GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Messages d'erreur
no endpoints available for service 'jobset-webhook-service'
Cette erreur signifie que le jobset n'a pas été installé correctement. Vérifiez si les pods Kubernetes de déploiement jobset-controller-manager sont en cours d'exécution. Pour en savoir plus, consultez la documentation de dépannage JobSet.
TPU initialization failed: Failed to connect
Assurez-vous que votre nœud GKE utilise la version 1.30.4-gke.1348000 ou une version ultérieure (GKE 1.31 n'est pas compatible).
Configuration pour PyTorch
Cette section explique comment commencer à utiliser PJRT sur la version v6e avec PyTorch/XLA. La version recommandée est Python 3.10.
Configurer PyTorch à l'aide de GKE avec XPK
Vous pouvez utiliser le conteneur Docker suivant avec XPK, qui contient déjà les dépendances PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Pour créer une charge de travail XPK, utilisez la commande suivante:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
L'utilisation de --base-docker-image
crée une nouvelle image Docker avec le répertoire de travail actuel intégré au nouveau Docker.
Configurer PyTorch à l'aide de ressources en file d'attente
Suivez ces étapes pour installer PyTorch à l'aide de ressources mises en file d'attente et exécuter un petit script sur la version 6e.
Installer des dépendances à l'aide de SSH pour accéder aux VM
Pour Multislice, ajoutez --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Améliorer les performances des modèles avec des allocations importantes et fréquentes
Pour les modèles comportant des allocations fréquentes et importantes, nous avons constaté que l'utilisation de tcmalloc
améliore considérablement les performances par rapport à l'implémentation par défaut de malloc
. Par conséquent, la valeur par défaut de malloc
utilisée sur la VM TPU est tcmalloc
. Toutefois, en fonction de votre charge de travail (par exemple, avec DLRM qui dispose d'allocations très importantes pour ses tables d'imbrication), tcmalloc
peut entraîner un ralentissement. Dans ce cas, vous pouvez essayer de réinitialiser la variable suivante à l'aide de malloc
par défaut:
unset LD_PRELOAD
Utilisez un script Python pour effectuer un calcul sur la VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Un résultat semblable à celui-ci doit s'afficher :
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configuration pour TensorFlow
Pour la version Preview publique de la version 6e, seule la version d'environnement d'exécution tf-nightly est prise en charge.
Vous pouvez réinitialiser tpu-runtime
avec la version TensorFlow compatible avec la version v6e en exécutant les commandes suivantes:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Utilisez SSH pour accéder à worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Installez TensorFlow sur worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exportez la variable d'environnement TPU_NAME
:
export TPU_NAME=v6e-16
Vous pouvez exécuter le script Python suivant pour vérifier le nombre de cœurs TPU disponibles dans votre tranche et pour vérifier que tout est correctement installé (les sorties affichées ont été produites avec une tranche v6e-16):
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
Le résultat ressemble à ce qui suit :
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e avec SkyPilot
Vous pouvez utiliser TPU v6e avec SkyPilot. Suivez les étapes ci-dessous pour ajouter des informations sur la localisation/la tarification liées à la v6e à SkyPilot.
Ajoutez les éléments suivants à la fin de
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Spécifiez les ressources suivantes dans un fichier YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Lancez un cluster avec TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Connectez-vous au TPU v6e à l'aide de SSH:
ssh tpu_v6
Tutoriels sur l'inférence
Les tutoriels suivants montrent comment exécuter une inférence sur un TPU v6e:
Exemples d'entraînement
Les sections suivantes fournissent des exemples d'entraînement des modèles MaxText, MaxDiffusion et PyTorch sur TPU v6e.
Entraînement MaxText et MaxDiffusion sur une VM Cloud TPU v6e
Les sections suivantes couvrent le cycle de vie de l'entraînement des modèles MaxText et MaxDiffusion.
En général, les étapes de haut niveau sont les suivantes:
- Créez l'image de base de la charge de travail.
- Exécutez votre charge de travail à l'aide de XPK.
- Créez la commande d'entraînement pour la charge de travail.
- Déployez la charge de travail.
- Suivez la charge de travail et affichez les métriques.
- Supprimez la charge de travail XPK si elle n'est pas nécessaire.
- Supprimez le cluster XPK lorsqu'il n'est plus nécessaire.
Créer une image de base
Installez MaxText ou MaxDiffusion, puis créez l'image Docker:
Clonez le dépôt que vous souhaitez utiliser et accédez au répertoire du dépôt:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configurez Docker pour qu'il utilise la Google Cloud CLI:
gcloud auth configure-docker
Créez l'image Docker à l'aide de la commande suivante ou à l'aide de la pile stable JAX. Pour en savoir plus sur la pile stable JAX, consultez Créer une image Docker avec la pile stable JAX.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
Si vous lancez la charge de travail à partir d'une machine sur laquelle l'image n'est pas compilée localement, importez-la:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Créer une image Docker avec la pile stable JAX
Vous pouvez créer les images Docker MaxText et MaxDiffusion à l'aide de l'image de base de la pile stable JAX.
La pile stable JAX fournit un environnement cohérent pour MaxText et MaxDiffusion en regroupant JAX avec des packages de base tels que orbax
, flax
et optax
, ainsi qu'une libtpu.so bien qualifiée qui gère les utilitaires de programme TPU et d'autres outils essentiels. Ces bibliothèques sont testées pour garantir la compatibilité et fournir une base stable pour créer et exécuter MaxText et MaxDiffusion. Cela élimine les conflits potentiels dus à des versions de paquets incompatibles.
La pile stable JAX inclut une libtpu.so entièrement publiée et qualifiée, la bibliothèque principale qui gère la compilation, l'exécution et la configuration réseau ICI des programmes TPU. La version libtpu remplace le build quotidien précédemment utilisé par JAX et garantit la fonctionnalité cohérente des calculs XLA sur le TPU avec des tests de qualification au niveau PJRT dans les IR HLO/StableHLO.
Pour créer l'image Docker MaxText et MaxDiffusion avec la pile stable JAX, lorsque vous exécutez le script docker_build_dependency_image.sh
, définissez la variable MODE
sur stable_stack
et la variable BASEIMAGE
sur l'image de base que vous souhaitez utiliser.
L'exemple suivant spécifie us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
comme image de base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
Pour obtenir la liste des images de base de la pile stable JAX disponibles, consultez la section Images de la pile stable JAX dans Artifact Registry.
Exécuter votre charge de travail à l'aide de XPK
Définissez les variables d'environnement suivantes si vous n'utilisez pas les valeurs par défaut définies par MaxText ou MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Créez le script de votre modèle. Ce script sera copié en tant que commande d'entraînement à une étape ultérieure.
N'exécutez pas encore le script du modèle.
MaxText
MaxText est un LLM Open Source hautes performances et hautement évolutif, écrit en Python et JAX purs, et ciblant les Google Cloud TPU et GPU pour l'entraînement et l'inférence.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma est une famille de LLM à poids ouverts développés par Google DeepMind, basés sur la recherche et la technologie Gemini.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral est un modèle d'IA de pointe développé par Mistral AI, qui utilise une architecture MoE (Mixture of Experts) sporadique.
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama est une famille de LLM à pondération ouverte développée par Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion est un ensemble d'implémentations de référence de divers modèles de diffusion latente écrits en Python et JAX purs, qui s'exécutent sur des appareils XLA, y compris des Cloud TPU et des GPU. Stable Diffusion est un modèle de texte vers image latent qui génère des images photoréalistes à partir de n'importe quelle entrée textuelle.
Vous devez installer une branche Git spécifique pour exécuter MaxDiffusion, comme indiqué dans la commande
git checkout
suivante.git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Script d'entraînement:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
Exécutez le modèle à l'aide du script que vous avez créé à l'étape précédente. Vous devez spécifier l'option
--base-docker-image
pour utiliser l'image de base MaxText ou l'option--docker-image
et l'image que vous souhaitez utiliser.Facultatif: vous pouvez activer la journalisation de débogage en incluant l'option
--enable-debug-logs
. Pour en savoir plus, consultez Déboguer JAX sur MaxText.Facultatif: vous pouvez créer un Vertex AI Experiment pour importer des données dans Vertex AI TensorBoard en incluant l'indicateur
--use-vertex-tensorboard
. Pour en savoir plus, consultez Surveiller JAX sur MaxText à l'aide de Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=$ACCELERATOR_TYPE \ --num-slices=$NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command $YOUR-MODEL-SCRIPT
Exportez les variables suivantes:
export ClUSTER_NAME=CLUSTER_NAME: nom de votre cluster XPK. export ACCELERATOR_TYPEACCELERATOR_TYPE: version et taille de votre TPU. Par exemple,
v6e-256
. export NUM_SLICES=NUM_SLICES: nombre de tranches TPU. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: script du modèle à exécuter en tant que commande d'entraînement.Le résultat inclut un lien permettant de suivre votre charge de travail, semblable à celui-ci:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Ouvrez le lien, puis cliquez sur l'onglet Journaux pour suivre votre charge de travail en temps réel.
Déboguer JAX sur MaxText
Utilisez des commandes XPK supplémentaires pour diagnostiquer pourquoi le cluster ou la charge de travail ne s'exécute pas.
- Liste des charges de travail XPK
- Inspecteur XPK
- Activez la journalisation détaillée dans vos journaux de charge de travail à l'aide de l'option
--enable-debug-logs
lorsque vous créez la charge de travail XPK.
Surveiller JAX sur MaxText à l'aide de Vertex AI
Afficher les données scalaires et de profil via TensorBoard géré par Vertex AI.
- Augmentez le nombre de requêtes de gestion des ressources (CRUD) pour la zone que vous utilisez de 600 à 5 000. Cela peut ne pas poser de problème pour les petites charges de travail utilisant moins de 16 VM.
Installez des dépendances telles que
cloud-accelerator-diagnostics
pour Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Créez votre cluster XPK à l'aide de l'option
--create-vertex-tensorboard
, comme décrit dans Créer Vertex AI TensorBoard. Vous pouvez également exécuter cette commande sur des clusters existants.Créez votre test Vertex AI lorsque vous exécutez votre charge de travail XPK à l'aide de l'indicateur
--use-vertex-tensorboard
et de l'indicateur--experiment-name
facultatif. Pour obtenir la liste complète des étapes, consultez Créer un test Vertex AI pour importer des données dans Vertex AI TensorBoard.
Les journaux incluent un lien vers un Vertex AI TensorBoard, semblable à celui-ci:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
Vous pouvez également trouver le lien Vertex AI TensorBoard dans la console Google Cloud. Accédez à Tests Vertex AI dans la console Google Cloud. Sélectionnez la région appropriée dans le menu déroulant.
Le répertoire TensorBoard est également écrit dans le bucket Cloud Storage que vous avez spécifié avec ${BASE_OUTPUT_DIR}
.
Supprimer des charges de travail XPK
Utilisez la commande xpk workload delete
pour supprimer une ou plusieurs charges de travail en fonction du préfixe ou de l'état de la tâche. Cette commande peut être utile si vous avez envoyé des charges de travail XPK qui n'ont plus besoin d'être exécutées ou si des tâches sont bloquées dans la file d'attente.
Supprimer le cluster XPK
Utilisez la commande xpk cluster delete
pour supprimer un cluster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone $ZONE --project $PROJECT_ID
Entraînement Llama et PyTorch/XLA sur une VM Cloud TPU v6e
Ce tutoriel explique comment entraîner des modèles Llama à l'aide de PyTorch/XLA sur TPU v6e à l'aide de l'ensemble de données WikiText.
Accéder à Hugging Face et au modèle Llama 3
Vous avez besoin d'un jeton d'accès utilisateur Hugging Face pour suivre ce tutoriel. Pour en savoir plus sur la création et l'utilisation de jetons d'accès utilisateur, consultez la documentation de Hugging Face sur les jetons d'accès utilisateur.
Vous devez également disposer d'une autorisation pour accéder au modèle Llama 3 8B sur Hugging Face. Pour y accéder, accédez au modèle Meta-Llama-3-8B sur Hugging Face et demandez l'accès.
Créer une VM TPU
Créez un TPU v6e avec huit puces pour exécuter le tutoriel.
Configurez des variables d'environnement :
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT=YOUR_PROJECT export ZONE=YOUR_ZONE
Créez une VM TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \ --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT
Installation
Installez le fork pytorch-tpu/transformers
de Hugging Face Transformers et ses dépendances. Ce tutoriel a été testé avec les versions de dépendance suivantes utilisées dans cet exemple:
torch
: compatible avec la version 2.5.0torch_xla[tpu]
: compatible avec la version 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Configurer les configurations de modèle
La commande d'entraînement de la section suivante, Exécuter le modèle, utilise deux fichiers de configuration JSON pour définir les paramètres du modèle et la configuration FSDP (Fully Sharded Data Parallel). Le fractionnement FSDP est utilisé pour que les poids du modèle s'adaptent à une taille de lot plus importante lors de l'entraînement. Lors de l'entraînement avec des modèles plus petits, il peut suffire d'utiliser le parallélisme des données et de répliquer les poids sur chaque appareil. Pour en savoir plus sur le fractionnement des tenseurs sur plusieurs appareils dans PyTorch/XLA, consultez le guide de l'utilisateur de SPMD PyTorch/XLA.
Créez le fichier de configuration des paramètres du modèle. Voici la configuration des paramètres du modèle pour Llama3-8B. Pour les autres modèles, recherchez la configuration sur Hugging Face. Consultez par exemple la configuration Llama2-7B.
cat > llama-config.json <
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF Créez le fichier de configuration FSDP:
cat > fsdp-config.json <
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF Pour en savoir plus sur le FSDP, consultez FSDPv2.
Importez les fichiers de configuration dans vos VM TPU à l'aide de la commande suivante:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Exécuter le modèle
À l'aide des fichiers de configuration que vous avez créés dans la section précédente, exécutez le script run_clm.py
pour entraîner le modèle Llama 3 8B sur l'ensemble de données WikiText. L'exécution du script d'entraînement sur un TPU v6e-8 prend environ 10 minutes.
Connectez-vous à Hugging Face sur votre TPU à l'aide de la commande suivante:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Exécutez l'entraînement du modèle:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Résoudre les problèmes liés à PyTorch/XLA
Si vous définissez les variables facultatives pour le débogage dans la section précédente, le profil du modèle sera stocké à l'emplacement spécifié par la variable PROFILE_LOGDIR
. Vous pouvez extraire le fichier xplane.pb
stocké à cet emplacement et utiliser tensorboard
pour afficher les profils dans votre navigateur à l'aide des instructions TensorBoard. Si PyTorch/XLA ne fonctionne pas comme prévu, consultez le guide de dépannage, qui contient des suggestions pour déboguer, profiler et optimiser votre modèle.
Entraînement DLRM DCN v2 sur v6e
Ce tutoriel vous explique comment entraîner le modèle DLRM DCN v2 sur TPU v6e. Vous devez provisionner un TPU v6e avec 64, 128 ou 256 puces.
Si vous exécutez sur plusieurs hôtes, réinitialisez tpu-runtime
avec la version TensorFlow appropriée en exécutant la commande suivante. Si vous exécutez le programme sur un seul hôte, vous n'avez pas besoin d'exécuter les deux commandes suivantes.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Se connecter en SSH à worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Définir le nom du TPU
export TPU_NAME=${TPU_NAME}
Exécuter DLRM v2
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Exécutez script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Les options suivantes sont nécessaires pour exécuter des charges de travail de recommandation (DCN DLRM):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Résultats du benchmark
La section suivante contient les résultats des analyses comparatives pour le DLRM DCN v2 et MaxDiffusion sur la version 6e.
DLRM DCN v2
Le script d'entraînement DLRM DCN v2 a été exécuté à différentes échelles. Consultez les débits dans le tableau suivant.
v6e-64 | v6e-128 | v6e-256 | |
Étapes de l'entraînement | 7000 | 7000 | 7000 |
Taille du lot global | 131072 | 262144 | 524288 |
Débit (exemples/s) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Nous avons exécuté le script d'entraînement pour MaxDiffusion sur un v6e-4, un v6e-16 et un 2xv6e-16. Consultez les débits dans le tableau suivant.
v6e-4 | v6e-16 | Deux v6e-16 | |
Étapes de l'entraînement | 0,069 | 0.073 | 0,13 |
Taille du lot global | 8 | 32 | 64 |
Débit (exemples/s) | 115,9 | 438,4 | 492,3 |
Planification de la collecte
Trillium (v6e) inclut une nouvelle fonctionnalité appelée "Planification de la collecte". Cette fonctionnalité permet de gérer plusieurs tranches de TPU exécutant une charge de travail d'inférence à hôte unique à la fois dans GKE et dans l'API Cloud TPU. Le regroupement de ces tranches dans une collection permet d'ajuster facilement le nombre de réplicas en fonction de la demande. Les mises à jour logicielles sont soigneusement contrôlées pour s'assurer qu'une partie des tranches de la collection est toujours disponible pour gérer le trafic entrant.
Pour en savoir plus sur l'utilisation de la planification de la collecte avec GKE, consultez la documentation GKE.
La fonctionnalité de planification de la collecte ne s'applique qu'à la version 6e.
Utiliser la planification de la collecte à partir de l'API Cloud TPU
Une collection à hôte unique dans l'API Cloud TPU est une ressource mise en file d'attente sur laquelle un indicateur spécial (--workload-type = availability-optimized
) est défini pour indiquer à l'infrastructure sous-jacente qu'elle est destinée à servir des charges de travail.
La commande suivante provisionne une collection à hôte unique à l'aide de l'API Cloud TPU:
gcloud alpha compute tpus queued-resources create my-collection \ --project=$PROJECT_ID \ --zone=${ZONE} \ --accelerator-type $ACCELERATOR_TYPE \ --node-count ${NODE_COUNT} \ --workload-type=availability-optimized
Surveiller et profiler
Cloud TPU v6e prend en charge la surveillance et le profilage à l'aide des mêmes méthodes que les générations précédentes de Cloud TPU. Pour en savoir plus sur la surveillance, consultez la page Surveiller les VM TPU.