Introducción a Trillium (v6e)
En esta documentación, la API de TPU y los registros, v6e se usa para referirse a Trillium. v6e representa la 6ª generación de TPU de Google.
Con 256 chips por Pod, la arquitectura v6e comparte muchas similitudes con la v5e. Este sistema está optimizado para el entrenamiento, la optimización y la entrega de transformadores, texto a imagen y redes neuronales convolucionales (CNN).
Para obtener más información sobre la arquitectura y las configuraciones del sistema v6e, consulta TPU v6e.
En este documento de introducción, se enfocan los procesos de entrenamiento y entrega de modelos con los frameworks de JAX, PyTorch o TensorFlow. Con cada framework, puedes provisionar TPU con recursos en cola o GKE. La configuración de GKE se puede realizar con XPK o comandos de GKE.
Procedimiento general para entrenar o entregar un modelo con la versión v6e
- Prepara un Google Cloud proyecto
- Capacidad segura
- Aprovisiona el entorno de Cloud TPU
- Ejecuta una carga de trabajo de entrenamiento o inferencia de modelos
Prepara un proyecto de Google Cloud
Antes de poder usar Cloud TPU, debes hacer lo siguiente:
- Crea una Google Cloud cuenta y un proyecto con la facturación habilitada
- Instala los componentes alfa de Google Cloud CLI
- Habilita la API de Cloud TPU
- Crea un agente de servicio de Cloud TPU
- Crea una cuenta de servicio de Cloud TPU y otorga permisos
Para obtener más información, consulta Configura el entorno de Cloud TPU.
Cómo proteger la capacidad
Comunícate con el equipo de asistencia deGoogle Cloud para solicitar una cuota de Cloud TPU v6e y responder cualquier pregunta sobre la capacidad.
Aprovisiona el entorno de Cloud TPU
Las Cloud TPU v6e se pueden aprovisionar y administrar con GKE, con GKE y XPK (una herramienta de wrapper de CLI sobre GKE) o como recursos en cola.
Requisitos previos
- Verifica que tu proyecto tenga suficiente cuota de
TPUS_PER_TPU_FAMILY
, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto de Google Cloud. - La versión 6e se probó con la siguiente configuración:
- Python
3.10
o versiones posteriores - Versiones de software nocturnas:
- JAX nocturno
0.4.32.dev20240912
- LibTPU nocturna
0.1.dev20240912+nightly
- JAX nocturno
- Versiones de software estables:
- JAX + JAX Lib de la versión 0.4.37
- Python
Verifica que tu proyecto tenga suficiente cuota para lo siguiente:
- Cuota de VM de Cloud TPU
- Cuota de direcciones IP
Quota de Hyperdisk Balanced
Si usas GKE con XPK, consulta Permisos de la consola de Cloud en la cuenta de usuario o de servicio para conocer los permisos necesarios para ejecutar XPK.
Crea variables de entorno
En Cloud Shell, crea las siguientes variables de entorno:
export NODE_ID=your-tpu-name export PROJECT_ID=your-project-id export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=your-queued-resource-id export VALID_DURATION=your-duration # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=number-of-slices # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Descripciones de las marcas de comandos
Variable | Descripción |
NODE_ID | Es el ID asignado por el usuario de la Cloud TPU que se crea cuando se asigna la solicitud de recurso en cola. |
ID DEL PROYECTO | Google Cloud nombre del proyecto. Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Cómo configurar tu Google Cloud proyecto. |
ZONA | Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles. |
ACCELERATOR_TYPE | Consulta Tipos de aceleradores. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en
Google Cloud Consola -> IAM -> Cuentas de servicio
Por ejemplo: tpu-service-account@ |
NUM_SLICES | Es la cantidad de rebanadas que se deben crear (solo es necesario para Multislice). |
QUEUED_RESOURCE_ID | El ID de texto asignado por el usuario de la solicitud de recursos en cola. |
VALID_DURATION | Es la duración durante la cual la solicitud de recursos en cola es válida. |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Optimiza el rendimiento de la red
Para obtener el mejor rendimiento, usa una red con 8,896 MTU (unidad de transmisión máxima).
De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que proporcionará un rendimiento de red poco óptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.
Para obtener más información sobre cómo cambiar la configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.
En el siguiente ejemplo, se crea una red con 8,896 MTU.
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
Cómo usar varias NIC (opción para Multislice)
Las siguientes variables de entorno son necesarias para una subred secundaria cuando usas un entorno de Multislice.
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
Usa los siguientes comandos para crear enrutamiento de IP personalizado para la red y la subred.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
--network=${NETWORK_NAME_2} \
--range=10.10.0.0/18 --region=${REGION} \
--project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
--network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
--project=${PROJECT_ID} \
--network=${NETWORK_NAME_2} \
--region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
--router=${ROUTER_NAME} \
--region=${REGION} \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project=${PROJECT_ID} \
--enable-logging
Después de crear una porción de varias redes, puedes validar que se estén usando ambas tarjetas de interfaz de red (NIC) configurando un clúster de XPK y agregando la marca --command ifconfig
al comando de creación de cargas de trabajo de XPK.
Usa el siguiente comando xpk workload
para mostrar el resultado del comando ifconfig
en los registros de la consola de Google Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.
python3 xpk.py workload create \ --cluster your-cluster-name \ (--base-docker-image maxtext_base_image|--docker-image your-cloud-image-name \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Verifica que tanto eth0 como eth1 tengan mtu=8,896. Para verificar que se esté ejecutando la NIC múltiple, agrega la marca --command ifconfig
al comando de creación de cargas de trabajo de XPK. Verifica el resultado de esa carga de trabajo de xpk en los registros de la consola de Google Cloud y comprueba que tanto eth0 como eth1 tengan mtu=8896.
Mejora la configuración de TCP
Si creaste tus Cloud TPU con la interfaz de recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT}" \ --zone "${ZONE}" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
Aprovisiona con recursos en cola
Puedes crear un Cloud TPU v6e con recursos en cola. Los recursos en fila te permiten recibir capacidad una vez que esté disponible. Puedes especificar una hora de inicio y finalización opcionales para el momento en que se debe completar la solicitud. Para obtener más información, consulta Administra recursos en cola.
Aprovisiona Cloud TPU v6e con GKE o XPK
Si usas comandos de GKE con la versión 6e, puedes usar comandos de Kubernetes o XPK para aprovisionar Cloud TPU y entrenar o entregar modelos. Consulta Planifica Cloud TPU en GKE para obtener información sobre cómo planificar tus configuraciones de Cloud TPU en clústeres de GKE. En las siguientes secciones, se proporcionan comandos para crear un clúster de XPK con compatibilidad con una sola NIC y con varias NIC.
Crea un clúster de XPK con compatibilidad con una sola NIC
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-central2-b export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments=${CLUSTER_ARGUMENTS} \ --create-vertex-tensorboard
Descripciones de las marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Es el nombre asignado por el usuario para el clúster de XPK. |
ID DEL PROYECTO | Google Cloud nombre del proyecto. Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Cómo configurar tu Google Cloud proyecto. |
ZONA | Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | La cantidad de divisiones que deseas crear |
CLUSTER_ARGUMENTS | La red y la subred que se usarán.
Por ejemplo: |
NUM_SLICES | Es la cantidad de rebanadas que se crearán. |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Crea un clúster de XPK con compatibilidad con varias NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster=${CLUSTER_NAME} \
--num-slices=${NUM_SLICES} \
--tpu-type=${TPU_TYPE} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--on-demand \
--custom-cluster-arguments=${CLUSTER_ARGUMENTS} \
--custom-nodepool-arguments=${NODE_POOL_ARGUMENTS} \
--create-vertex-tensorboard
Descripciones de las marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Es el nombre asignado por el usuario para el clúster de XPK. |
ID DEL PROYECTO | Google Cloud nombre del proyecto. Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Cómo configurar tu Google Cloud proyecto. |
ZONA | Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | La cantidad de divisiones que deseas crear |
CLUSTER_ARGUMENTS | La red y la subred que se usarán.
Por ejemplo: |
NODE_POOL_ARGUMENTS | Es la red de nodos adicional que se usará.
Por ejemplo: |
NUM_SLICES | Es la cantidad de rebanadas que se deben crear (solo es necesario para Multislice). |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Configuración del framework
En esta sección, se describe el proceso de configuración general para el entrenamiento de modelos de AA con los frameworks JAX, PyTorch o TensorFlow. Si usas GKE, puedes usar XPK o comandos de Kubernetes para configurar el framework.
Configuración para JAX
En esta sección, se proporcionan instrucciones de configuración para ejecutar cargas de trabajo de JAX en GKE, con o sin XPK, así como para usar recursos en cola.
Configura JAX con GKE
Porción única en un solo host
En el siguiente ejemplo, se configura un grupo de nodos de un solo host de 2 × 2 con un archivo YAML de Kubernetes.
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:
Total TPU chips: 4
Porción única en varios hosts
En el siguiente ejemplo, se configura un grupo de nodos multihost 4 × 4 con un archivo YAML de Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:
Total TPU chips: 16
Porciones múltiples en varios hosts
En el siguiente ejemplo, se configuran dos grupos de nodos multihost de 4 × 4 con un archivo YAML de Kubernetes.
Como requisito previo, debes instalar JobSet v0.2.3 o una versión posterior.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:
Total TPU chips: 32
Para obtener más información, consulta Cómo ejecutar una carga de trabajo de porciones múltiples en la documentación de GKE.
Para mejorar el rendimiento, habilita hostNetwork.
Multi-NIC
Para aprovechar la NIC múltiple en GKE, el manifiesto del pod de Kubernetes debe tener anotaciones adicionales. El siguiente es un manifiesto de ejemplo de carga de trabajo de varias NIC que no es de TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Si usas el comando exec
para conectarte al Pod de Kubernetes, deberías ver la NIC adicional con el siguiente código.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configura JAX con GKE y XPK
Para configurar JAX con GKE y XPK, consulta el archivo readme de xpk.
Para configurar y ejecutar XPK con MaxText, consulta Cómo ejecutar MaxText.
Configura JAX con recursos en cola
Instala JAX en todas las VMs de Cloud TPU de tu o tus slices de forma simultánea con el comando gcloud alpha compute tpus tpu-vm ssh
. Para Multislice, agrega la marca --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Puedes ejecutar el siguiente comando para verificar cuántos núcleos de Cloud TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
El resultado es similar al siguiente cuando se ejecuta en una porción v6e-16:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count()
muestra la cantidad total de chips en la porción determinada.
jax.local_device_count()
indica la cantidad de chips a los que puede acceder una sola VM en esta porción.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
Cómo solucionar problemas de configuración de JAX
Una sugerencia general es habilitar el registro detallado en el manifiesto de tu carga de trabajo de GKE. Luego, proporciona los registros al equipo de asistencia de GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Mensajes de error
no endpoints available for service 'jobset-webhook-service'
Este error significa que el conjunto de trabajos no se instaló correctamente. Verifica si se están ejecutando los pods de Kubernetes de la implementación de jobset-controller-manager. Para obtener más información, consulta la documentación de solución de problemas de JobSet.
TPU initialization failed: Failed to connect
Asegúrate de que la versión de tu nodo de GKE sea 1.30.4-gke.1348000 o posterior (no se admite GKE 1.31).
Configuración para PyTorch
En esta sección, se describe cómo comenzar a usar PJRT en v6e con PyTorch/XLA. La versión recomendada de Python es 3.10.
Configura PyTorch con GKE y XPK
Puedes usar el siguiente contenedor de Docker con XPK, que ya tiene instaladas las dependencias de PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Para crear una carga de trabajo de XPK, usa el siguiente comando:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
El uso de --base-docker-image
crea una nueva imagen de Docker con el directorio de trabajo
actual integrado en el nuevo Docker.
Configura PyTorch con recursos en cola
Sigue estos pasos para instalar PyTorch con recursos en cola y ejecutar una pequeña secuencia de comandos en v6e.
Instala dependencias con SSH para acceder a las VMs
Usa el siguiente comando para instalar dependencias en todas las VMs de Cloud TPU. Para Multislice, agrega la marca --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Mejora el rendimiento de los modelos con asignaciones grandes y frecuentes
Para los modelos que tienen asignaciones frecuentes y de tamaño, usar la función tcmalloc
mejora el rendimiento de manera significativa en comparación con la implementación predeterminada de la función malloc
, por lo que la función malloc
predeterminada que se usa en la VM de Cloud TPU es tcmalloc
. Sin embargo, según tu carga de trabajo (por ejemplo, con DLRM, que tiene asignaciones muy grandes para sus tablas de incorporación), la función tcmalloc
puede causar una ralentización. En ese caso, puedes intentar restablecer la siguiente variable con la función malloc
predeterminada:
unset LD_PRELOAD
Usa una secuencia de comandos de Python para realizar un cálculo en la VM v6e
Usa el siguiente comando para ejecutar una secuencia de comandos que cree dos tensores, los agregue juntos y muestre el resultado.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Esto genera un resultado similar al que se muestra a continuación:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configuración para TensorFlow
Para restablecer el entorno de ejecución de Cloud TPU con la versión de TensorFlow compatible con v6e, ejecuta los siguientes comandos:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Usa SSH para acceder a worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Instala TensorFlow en worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exporta la variable de entorno TPU_NAME
:
export TPU_NAME=v6e-16
Puedes ejecutar la siguiente secuencia de comandos de Python para verificar cuántos núcleos de Cloud TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente:
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
El resultado es similar al siguiente cuando se ejecuta en una porción v6e-16:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e con SkyPilot
Puedes usar Cloud TPU v6e con SkyPilot. Sigue estos pasos para agregar información de ubicación y precios relacionada con v6e a SkyPilot.
Agrega lo siguiente al final del archivo
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Especifica los siguientes recursos en un archivo YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Inicia un clúster con Cloud TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Conéctate a la Cloud TPU v6e con SSH:
ssh tpu_v6
Instructivos de inferencia
En los siguientes instructivos, se muestra cómo ejecutar inferencias en Cloud TPU v6e:
Ejemplos de entrenamiento
En las siguientes secciones, se proporcionan ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en Cloud TPU v6e.
Entrenamiento de MaxText y MaxDiffusion en la VM de Cloud TPU v6e
En las siguientes secciones, se describe el ciclo de vida de entrenamiento de los modelos MaxText y MaxDiffusion.
En general, los pasos de alto nivel son los siguientes:
- Compila la imagen base de la carga de trabajo.
- Ejecuta tu carga de trabajo con XPK.
- Compila el comando de entrenamiento para la carga de trabajo.
- Implementa la carga de trabajo.
- Sigue la carga de trabajo y consulta las métricas.
- Borra la carga de trabajo de XPK si no la necesitas.
- Borra el clúster de XPK cuando ya no lo necesites.
Cómo compilar la imagen base
Instala MaxText o MaxDiffusion y compila la imagen de Docker:
Clona el repositorio que deseas usar y cambia al directorio del repositorio:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configura Docker para usar Google Cloud CLI:
gcloud auth configure-docker
Compila la imagen de Docker con el siguiente comando o con la pila estable de JAX. Para obtener más información sobre JAX Stable Stack, consulta Cómo compilar una imagen de Docker con JAX Stable Stack.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Compila una imagen de Docker con JAX Stable Stack
Puedes compilar las imágenes de Docker de MaxText y MaxDiffusion con la imagen base de la pila estable de JAX.
El paquete estable de JAX proporciona un entorno coherente para MaxText y MaxDiffusion, ya que agrupa JAX con paquetes principales, como orbax
, flax
y optax
, junto con una libtpu.so bien calificada que impulsa las utilidades del programa de Cloud TPU y otras herramientas esenciales. Estas bibliotecas se prueban para garantizar la compatibilidad y proporcionar
una base estable para compilar y ejecutar MaxText y MaxDiffusion. Esto elimina posibles conflictos debido a versiones de paquetes incompatibles.
La pila estable de JAX incluye una libtpu.so completamente publicada y calificada, la biblioteca principal que impulsa la compilación, la ejecución y la configuración de la red de ICI del programa de Cloud TPU. La versión de libtpu reemplaza la compilación nocturna que usaba JAX anteriormente y garantiza una funcionalidad coherente de los cálculos de XLA en Cloud TPU con pruebas de calificación a nivel de PJRT en IR de HLO/StableHLO.
Para compilar la imagen de Docker de MaxText y MaxDiffusion con la pila estable de JAX, cuando
ejecutes la secuencia de comandos docker_build_dependency_image.sh
, establece la variable MODE
en
stable_stack
y la variable BASEIMAGE
en la imagen base que deseas
usar.
En el siguiente ejemplo, se especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
como la imagen base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
Para obtener una lista de las imágenes base de JAX Stable Stack disponibles, consulta Imágenes de JAX Stable Stack en Artifact Registry.
Ejecuta tu carga de trabajo con XPK
Establece las siguientes variables de entorno si no usas los valores predeterminados establecidos por MaxText o MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Compila la secuencia de comandos de tu modelo. Esta secuencia de comandos se copiará como un comando de entrenamiento en un paso posterior.
Aún no ejecutes la secuencia de comandos del modelo.
MaxText
MaxText es un LLM de código abierto, altamente escalable y de alto rendimiento, escrito en Python y JAX puros, y se orienta a Google Cloud TPU y GPUs para el entrenamiento y la inferencia.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma es una familia de LLM con pesos abiertos que desarrolló Google DeepMind, basada en la investigación y la tecnología de Gemini.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI que utiliza una arquitectura dispersa de mezcla de expertos (MoE).
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama es una familia de LLM de ponderación abierta que desarrolló Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash
MaxDiffusion
MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latente escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las GPUs y las Cloud TPU. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.
Debes instalar una rama específica de Git para ejecutar MaxDiffusion, como se muestra en el siguiente comando
git checkout
.git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Secuencia de comandos de entrenamiento:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca
--base-docker-image
para usar la imagen base de MaxText o especificar la marca--docker-image
y la imagen que deseas usar.Opcional: Puedes habilitar el registro de depuración si incluyes la marca
--enable-debug-logs
. Para obtener más información, consulta Cómo depurar JAX en MaxText.Opcional: Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard si incluyes la marca
--use-vertex-tensorboard
. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command=$YOUR-MODEL-SCRIPT
Exporta las siguientes variables:
export CLUSTER_NAME=CLUSTER_NAME: The name of your XPK cluster. export ACCELERATOR_TYPEACCELERATOR_TYPE: The version and size of your TPU. For example, `v6e-256`. export NUM_SLICES=NUM_SLICES: The number of Cloud TPU slices. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: The model script to execute as a training command.
El resultado incluye un vínculo para seguir tu carga de trabajo, similar al siguiente:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.
Cómo depurar JAX en MaxText
Usa comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo.
- Lista de cargas de trabajo de XPK
- Inspector de XPK
- Habilita el registro detallado en tus registros de cargas de trabajo con la marca
--enable-debug-logs
cuando crees la carga de trabajo de XPK.
Supervisa JAX en MaxText con Vertex AI
Consulta los datos escalares y de perfil a través de TensorBoard administrado de Vertex AI.
- Aumenta las solicitudes de administración de recursos (CRUD) de la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.
Instala dependencias como
cloud-accelerator-diagnostics
para Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea tu clúster de XPK con la marca
--create-vertex-tensorboard
, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca
--use-vertex-tensorboard
y la marca opcional--experiment-name
. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.
Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
También puedes encontrar el vínculo de Vertex AI TensorBoard en la consola de Google Cloud. Ve a Vertex AI Experiments en la console de Google Cloud. Selecciona la región adecuada en el menú desplegable.
El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}
.
Borra cargas de trabajo de XPK
Usa el comando xpk workload delete
para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos que están bloqueados en la cola.
Borra el clúster de XPK
Usa el comando xpk cluster delete
para borrar un clúster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
Entrenamiento de Llama y PyTorch/XLA en la VM de Cloud TPU v6e
En este instructivo, se describe cómo entrenar modelos de Llama con PyTorch/XLA en Cloud TPU v6e con el conjunto de datos WikiText.
Obtén acceso a Hugging Face y al modelo Llama 3
Necesitas un token de acceso de usuario de Hugging Face para ejecutar este instructivo. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre los tokens de acceso de usuario.
También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en HuggingFace y solicita acceso.
Crea una VM de Cloud TPU
Crea una Cloud TPU v6e con 8 chips para ejecutar el instructivo.
Configure las variables de entorno:
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT_ID=your-project-id export ZONE=your-zone
Crea una VM de Cloud TPU:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Instalación
Instala la bifurcación pytorch-tpu/transformers
de los transformadores y las dependencias de Hugging Face. Este instructivo se probó con las siguientes versiones de dependencias que se usan en este ejemplo:
torch
: Compatible con 2.5.0torch_xla[tpu]
: Compatible con 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} --zone ${ZONE} \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Configura los parámetros de configuración del modelo
El comando de entrenamiento de la siguiente sección, Ejecuta el modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando entrenas con modelos más pequeños, podría ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Si deseas obtener más información para dividir tensores en varios dispositivos en PyTorch/XLA, consulta la Guía del usuario de SPMD de PyTorch/XLA.
Crea el archivo de configuración del parámetro del modelo. La siguiente es la configuración del parámetro del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Crea el archivo de configuración de FSDP:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Para obtener más información sobre FSDP, consulta FSDPv2.
Usa el siguiente comando para subir los archivos de configuración a tus VMs de Cloud TPU:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Ejecuta el modelo
Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py
para entrenar el modelo Llama 3 8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda alrededor de 10 minutos en ejecutarse en una Cloud TPU v6e-8.
Accede a Hugging Face en tu Cloud TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Ejecuta el entrenamiento del modelo:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project=${PROJECT} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Soluciona problemas de PyTorch/XLA
Si configuras las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación que especifique la variable PROFILE_LOGDIR
. Puedes extraer el archivo xplane.pb
almacenado en esta ubicación y usar tensorboard
para ver los perfiles en tu navegador con las instrucciones de TensorBoard. Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que tiene sugerencias para depurar, generar perfiles y optimizar tu modelo.
Entrenamiento de DLRM DCN v2 en v6e
En este instructivo, se muestra cómo entrenar el modelo DLRM DCN v2 en Cloud TPU v6e. Debes aprovisionar una TPU v6e con 64, 128 o 256 chips.
Si ejecutas en una TPU de varios hosts, ejecuta los siguientes comandos para restablecer tpu-runtime
con la versión correcta de TensorFlow. Si ejecutas la función en una TPU de un solo host, no necesitas ejecutar los siguientes dos comandos.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Conéctate a worker-0 con SSH
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Configura el nombre de Cloud TPU
export TPU_NAME=${TPU_NAME}
Ejecuta DLRM v2
Copia el siguiente fragmento de código en un archivo llamado script.sh
:
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Si ejecutas TensorFlow en GKE, instala la rueda de Cloud TPU de TensorFlow y libtpu con el siguiente comando:
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Establece las siguientes marcas, que son necesarias para ejecutar cargas de trabajo de Recommendation (como DLRM DCN):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Ejecuta script.sh
:
chmod +x script.sh
./script.sh
Resultados de comparativas
En la siguiente sección, se incluyen los resultados de las comparativas de DLRM DCN v2 y MaxDiffusion en v6e.
DLRM DCN v2
La secuencia de comandos de entrenamiento de DLRM DCN v2 se ejecutó a diferentes escalas. Consulta las tasas de transferencia en la siguiente tabla.
v6e-64 | v6e-128 | v6e-256 | |
---|---|---|---|
Pasos de entrenamiento | 7000 | 7000 | 7000 |
Tamaño del lote global | 131072 | 262144 | 524288 |
Capacidad de procesamiento (ejemplos/s) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Ejecutamos la secuencia de comandos de entrenamiento de MaxDiffusion en una v6e-4, una v6e-16 y dos v6e-16. Consulta las tasas de transferencia en la siguiente tabla.
v6e-4 | v6e-16 | Dos v6e-16 | |
---|---|---|---|
Pasos de entrenamiento | 0.069 | 0.073 | 0.13 |
Tamaño del lote global | 8 | 32 | 64 |
Capacidad de procesamiento (ejemplos/s) | 115.9 | 438.4 | 492.3 |