Introducción a Trillium (v6e)
En esta documentación, la API de TPU y los registros, v6e se usa para referirse a Trillium. v6e representa la 6ª generación de TPU de Google.
Con 256 chips por Pod, la arquitectura v6e comparte muchas similitudes con la v5e. Este sistema está optimizado para el entrenamiento, la optimización y la entrega de transformadores, texto a imagen y redes neuronales convolucionales (CNN).
Consulta el documento v6e para obtener información sobre la arquitectura y las configuraciones del sistema v6e.
En este documento de introducción, se enfocan los procesos de entrenamiento y publicación de modelos con los frameworks de JAX, PyTorch o TensorFlow. Con cada framework, puedes aprovisionar TPU con recursos en cola o Google Kubernetes Engine (GKE). La configuración de GKE se puede realizar con XPK o comandos de GKE.
Procedimiento general para entrenar o entregar un modelo con la versión v6e
- Prepara un Google Cloud proyecto
- Capacidad segura
- Configura tu entorno de TPU
- Aprovisiona el entorno de Cloud TPU
- Ejecuta una carga de trabajo de entrenamiento o inferencia de modelos
- Realiza una limpieza
Prepara un Google Cloud proyecto
- Accede a tu Cuenta de Google. Si aún no lo hiciste, regístrate para obtener una nueva cuenta.
- En la consola de Google Cloud, selecciona o crea un proyecto de Cloud en la página del selector de proyectos.
- Habilita la facturación para tu proyecto de Google Cloud. La facturación es obligatoria para todo el uso de Google Cloud.
- Instala los componentes de gcloud alpha.
Ejecuta el siguiente comando para instalar la versión más reciente de los componentes de
gcloud
.gcloud components update
Habilita la API de TPU con el siguiente comando
gcloud
en Cloud Shell. También puedes habilitarlo desde la consola de Google Cloud.gcloud services enable tpu.googleapis.com
Habilita permisos con la cuenta de servicio de TPU para la API de Compute Engine
Las cuentas de servicio permiten que el servicio de Cloud TPU acceda a otros servicios de Google Cloud. Una cuenta de servicio administrada por el usuario es una práctica recomendada de Google Cloud. Sigue estas guías para crear y otorgar roles. Se requieren los siguientes roles:
- Administrador de TPU
- Administrador de almacenamiento
- Escritor de registros
- Escritor de métricas de Monitoring
a. Configura los permisos de XPK con tu cuenta de usuario de GKE: XPK.
Realiza la autenticación con tu Cuenta de Google y establece el ID y la zona del proyecto predeterminados.
auth login
autoriza a gcloud a acceder a Google Cloud con credenciales de usuario de Google.
PROJECT_ID
es el Google Cloud nombre del proyecto.
ZONE
es la zona en la que deseas crear la TPU.gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Crea una identidad de servicio para la VM de TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Cómo proteger la capacidad
Comunícate con el equipo de asistencia de ventas o de cuentas de Cloud TPU para solicitar la cuota de TPU y responder cualquier pregunta sobre la capacidad.
Aprovisiona el entorno de Cloud TPU
Las TPU v6e se pueden aprovisionar y administrar con GKE, con GKE y XPK (una herramienta de wrapper de CLI sobre GKE) o como recursos en cola.
Requisitos previos
- Verifica que tu proyecto tenga suficiente cuota de
TPUS_PER_TPU_FAMILY
, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto deGoogle Cloud . - La versión 6e se probó con la siguiente configuración:
- Python
3.10
o versiones posteriores - Versiones de software nocturnas:
- JAX nocturno
0.4.32.dev20240912
- LibTPU nocturna
0.1.dev20240912+nightly
- JAX nocturno
- Versiones de software estables:
- JAX + JAX Lib de la versión 0.4.37
- Python
Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
- Cuota de VM de TPU
- Cuota de direcciones IP
Quota de Hyperdisk Balanced
Permisos del proyecto del usuario
- Si usas GKE con XPK, consulta Permisos de la consola de Google Cloud en la cuenta de usuario o de servicio para conocer los permisos necesarios para ejecutar XPK.
Variables de entorno
En Cloud Shell, crea las siguientes variables de entorno:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-east1-d export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=your-service-account export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for provisioning Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Descripciones de las marcas de comandos
Variable | Descripción |
NODE_ID | El ID asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en fila. |
ID DEL PROYECTO | Google Cloud Nombre del proyecto. Usa un proyecto existente o crea uno nuevo en |
ZONA | Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles. |
ACCELERATOR_TYPE | Consulta Tipos de aceleradores. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en
la Google Cloud Console -> IAM -> Cuentas de servicio.
Por ejemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | Es la cantidad de rebanadas que se deben crear (solo es necesario para Multislice). |
QUEUED_RESOURCE_ID | El ID de texto asignado por el usuario de la solicitud de recursos en cola. |
VALID_DURATION | Es la duración durante la cual la solicitud de recursos en cola es válida. |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Optimizaciones del rendimiento de la red
Para obtener el mejor rendimiento, usa una red con 8,896 MTU (unidad de transmisión máxima).
De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que proporcionará un rendimiento de red poco óptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.
Para obtener más información sobre cómo cambiar la configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.
En el siguiente ejemplo, se crea una red con 8,896 MTU.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} --allow tcp,icmp,udp --project=${PROJECT}
Cómo usar varias NIC (opción para Multislice)
Las siguientes variables de entorno son necesarias para una subred secundaria cuando usas un entorno de Multislice.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Usa los siguientes comandos para crear enrutamiento de IP personalizado para la red y la subred.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT_ID}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT_ID}" \
--enable-logging
Una vez que se crea una porción de varias redes, puedes validar que se usen ambas NIC configurando un clúster de XPK y ejecutando --command ifconfig
como parte de la carga de trabajo de XPK.
Usa el siguiente comando xpk workload
para mostrar el resultado del comando ifconfig
en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.
python3 xpk.py workload create \ --cluster CLUSTER_NAME \ (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Verifica que tanto eth0 como eth1 tengan mtu=8,896. Una forma de verificar que tienes varios NIC en ejecución es ejecutar el comando --command "ifconfig" como parte de la carga de trabajo de XPK. Luego, observa el resultado impreso de esa carga de trabajo de xpk en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.
Configuración de TCP mejorada
En el caso de las TPU creadas con la interfaz de recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" \ --zone "$ZONE" \ --node=all \ --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \ --worker=all
Aprovisionamiento con recursos en cola
La capacidad asignada se puede aprovisionar con el comando create
de recursos en cola.
Crea una solicitud de recurso en cola de TPU.
La marca
--reserved
solo es necesaria para los recursos reservados, no para los recursos on demand.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved] # The following flags are only needed if you are using Multislice. --node-count node-count # Number of slices in a Multislice \ --node-prefix node-prefix # An optional user-defined node prefix; the default is QUEUED_RESOURCE_ID.
Si la solicitud de recursos en cola se crea correctamente, el estado en el campo "response" será "WAITING_FOR_RESOURCES" o "FAILED". Si la solicitud de recursos en cola está en el estado "WAITING_FOR_RESOURCES", significa que el recurso se agregó a la cola y se aprovisionará cuando haya suficiente capacidad de TPU asignada. Si la solicitud de recursos en cola está en el estado "FAILED", el motivo de la falla aparecerá en el resultado. La solicitud de recursos en cola vencerá si no se aprovisiona un v6e dentro de la duración especificada y el estado se convierte en “FAILED”. Consulta la documentación pública de Recursos en cola para obtener más información.
Cuando tu solicitud de recursos en cola esté en el estado "ACTIVO", podrás conectarte a tus VMs de TPU con SSH. Usa los comandos
list
odescribe
para consultar el estado de tu recurso en cola.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Cuando el recurso en cola está en el estado "ACTIVE", el resultado es similar al siguiente:
state: state: ACTIVE
Administra tus VMs de TPU. Para conocer las opciones de administración de tus VMs de TPU, consulta Cómo administrar VMs de TPU.
Conéctate a tus VMs de TPU con SSH
Puedes instalar objetos binarios en cada VM de TPU de tu porción de TPU y ejecutar código. Consulta la sección Tipos de VM para determinar cuántas VMs tendrá tu fragmento.
Para instalar los objetos binarios o ejecutar código, puedes usar SSH para conectarte a una VM con el comando
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Para usar SSH y conectarte a una VM específica, usa la marca
--worker
que sigue a un índice basado en 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Si tienes formas de rebanada superiores a 8 chips, tendrás varias VM en una rebanada. En este caso, usa los parámetros
--worker=all
y--command
en el comandogcloud alpha compute tpus tpu-vm ssh
para ejecutar un comando en todas las VMs de forma simultánea. Por ejemplo:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Borrar un recurso en cola
Borra un recurso en cola al final de la sesión o quita las solicitudes de recursos en cola que estén en el estado "FAILED". Para borrar un recurso en cola, borra la porción y, luego, la solicitud de recurso en cola en 2 pasos:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Aprovisiona TPU v6e con GKE o XPK
Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar TPU y entrenar o entregar modelos. Consulta Planifica las TPU en GKE para obtener información sobre cómo planificar tus configuraciones de TPU en los clústeres de GKE. En las siguientes secciones, se proporcionan comandos para crear un clúster de XPK con compatibilidad con una sola NIC y con varias NIC.
Comandos para crear un clúster XPK con compatibilidad con una sola NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network ${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster $CLUSTER_NAME \ --cluster-cpu-machine-type=n1-standard-8 \ --num-slices=$NUM_SLICES \ --tpu-type=$TPU_TYPE \ --zone=$ZONE \ --project=$PROJECT \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Descripciones de las marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Es el nombre asignado por el usuario para el clúster de XPK. |
ID DEL PROYECTO | Google Cloud Nombre del proyecto. Usa un proyecto existente o crea uno nuevo en |
ZONA | Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | La cantidad de divisiones que deseas crear |
CLUSTER_ARGUMENTS | La red y la subred que se usarán.
Por ejemplo: "--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}" |
NUM_SLICES | Es la cantidad de rebanadas que se crearán. |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Comandos para crear un clúster de XPK con compatibilidad con varias NIC
export CLUSTER_NAME xpk-cluster-name export ZONE=us-central2-b export PROJECT=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create "${NETWORK_NAME_1}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_1}" \ --network="${NETWORK_NAME_1}" \ --range=10.11.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_1}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_1}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create "${NETWORK_NAME_2}" \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \ --network="${NETWORK_NAME_2}" \ --range=10.10.0.0/18 \ --region="${REGION}" \ --project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \ --network "${NETWORK_NAME_2}" \ --allow tcp,icmp,udp \ --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \ --project="${PROJECT}" \ --network="${NETWORK_NAME_2}" \ --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \ --router="${ROUTER_NAME}" \ --region="${REGION}" \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project="${PROJECT}" \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard
Descripciones de las marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Es el nombre asignado por el usuario para el clúster de XPK. |
ID DEL PROYECTO | Google Cloud Nombre del proyecto. Usa un proyecto existente o crea uno nuevo en |
ZONA | Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | La cantidad de divisiones que deseas crear |
CLUSTER_ARGUMENTS | La red y la subred que se usarán.
Por ejemplo: "--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}" |
NODE_POOL_ARGUMENTS | Es la red de nodos adicional que se usará.
Por ejemplo: "--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}" |
NUM_SLICES | Es la cantidad de rebanadas que se deben crear (solo es necesario para Multislice). |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Configuración del framework
En esta sección, se describe el proceso de configuración general para el entrenamiento de modelos de AA con los frameworks JAX, PyTorch o TensorFlow. Puedes aprovisionar TPUs con recursos en cola o GKE. La configuración de GKE se puede realizar con XPK o comandos de Kubernetes.
Configuración para JAX
En esta sección, se proporcionan ejemplos para ejecutar cargas de trabajo de JAX en GKE, con o sin XPK, así como para usar recursos en cola.
Configura JAX con GKE
En el siguiente ejemplo, se configura un host único de 2 × 2 con un archivo YAML de Kubernetes.
Porción única en un solo host
apiVersion: v1
kind: Pod
metadata:
name: tpu-pod-jax-v6e-a
spec:
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 2x2
containers:
- name: tpu-job
image: python:3.10
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:
Total TPU chips: 4
Porción única en varios hosts
En el siguiente ejemplo, se configura un grupo de nodos multihost 4 × 4 con un archivo YAML de Kubernetes.
apiVersion: v1
kind: Service
metadata:
name: headless-svc
spec:
clusterIP: None
selector:
job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
name: tpu-available-chips
spec:
backoffLimit: 0
completions: 4
parallelism: 4
completionMode: Indexed
template:
spec:
subdomain: headless-svc
restartPolicy: Never
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
containers:
- name: tpu-job
image: python:3.10
ports:
- containerPort: 8471 # Default port using which TPU VMs communicate
- containerPort: 8431 # Port to export TPU runtime metrics, if supported.
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
requests:
google.com/tpu: 4
limits:
google.com/tpu: 4
Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:
Total TPU chips: 16
Porciones múltiples en varios hosts
En el siguiente ejemplo, se configuran dos grupos de nodos multihost 4 × 4 con un archivo YAML de Kubernetes.
Como requisito previo, debes instalar JobSet v0.2.3 o una versión posterior.
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: multislice-job
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
failurePolicy:
maxRestarts: 4
replicatedJobs:
- name: slice
replicas: 2
template:
spec:
parallelism: 4
completions: 4
backoffLimit: 0
template:
spec:
hostNetwork: true
dnsPolicy: ClusterFirstWithHostNet
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
cloud.google.com/gke-tpu-topology: 4x4
hostNetwork: true
containers:
- name: jax-tpu
image: python:3.10
ports:
- containerPort: 8471
- containerPort: 8080
- containerPort: 8431
securityContext:
privileged: true
command:
- bash
- -c
- |
pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
resources:
limits:
google.com/tpu: 4
requests:
google.com/tpu: 4
Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:
Total TPU chips: 32
Para obtener más información, consulta Cómo ejecutar una carga de trabajo de porciones múltiples en la documentación de GKE.
Para mejorar el rendimiento, habilita hostNetwork.
Multi-NIC
Para aprovechar la NIC múltiple en GKE, el manifiesto del pod de Kubernetes debe tener anotaciones adicionales. El siguiente es un manifiesto de ejemplo de carga de trabajo de varias NIC que no es de TPU.
apiVersion: v1
kind: Pod
metadata:
name: sample-netdevice-pod-1
annotations:
networking.gke.io/default-interface: 'eth0'
networking.gke.io/interfaces: |
[
{"interfaceName":"eth0","network":"default"},
{"interfaceName":"eth1","network":"netdevice-network"}
]
spec:
containers:
- name: sample-netdevice-pod
image: busybox
command: ["sleep", "infinity"]
ports:
- containerPort: 80
restartPolicy: Always
tolerations:
- key: "google.com/tpu"
operator: "Exists"
effect: "NoSchedule"
Si exec
en el pod de Kubernetes, deberías ver la NIC adicional con el siguiente código.
$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
inet 127.0.0.1/8 scope host lo
valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
inet 172.24.0.4/32 scope global eth1
valid_lft forever preferred_lft forever
Configura JAX con GKE y XPK
Consulta un ejemplo en el archivo README de xpk.
Para configurar y ejecutar XPK con MaxText, consulta: Cómo ejecutar MaxText.
Configura JAX con recursos en cola
Instala JAX en todas las VMs de TPU de tu porción o porciones de forma simultánea con gcloud alpha compute tpus tpu-vm ssh
. Para Multislice, agrega --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Puedes ejecutar el siguiente código de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran aquí se generaron con un fragmento v6e-16):
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
El resultado es similar a este:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() muestra la cantidad total de chips en la porción determinada. jax.local_device_count() indica la cantidad de chips a los que puede acceder una sola VM en esta porción.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
&& pip install -r requirements.txt && pip install . '
Cómo solucionar problemas de configuración de JAX
Una sugerencia general es habilitar el registro detallado en el manifiesto de tu carga de trabajo de GKE. Luego, proporciona los registros al equipo de asistencia de GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Mensajes de error
no endpoints available for service 'jobset-webhook-service'
Este error significa que el conjunto de trabajos no se instaló correctamente. Verifica si se están ejecutando los pods de Kubernetes de la implementación de jobset-controller-manager. Para obtener más información, consulta la documentación de solución de problemas de JobSet.
TPU initialization failed: Failed to connect
Asegúrate de que la versión de tu nodo de GKE sea 1.30.4-gke.1348000 o posterior (no se admite GKE 1.31).
Configuración para PyTorch
En esta sección, se describe cómo comenzar a usar PJRT en v6e con PyTorch/XLA. La versión recomendada de Python es 3.10.
Configura PyTorch con GKE y XPK
Puedes usar el siguiente contenedor de Docker con XPK, que ya tiene instaladas las dependencias de PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Para crear una carga de trabajo de XPK, usa el siguiente comando:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
El uso de --base-docker-image
crea una nueva imagen de Docker con el directorio de trabajo
actual integrado en el nuevo Docker.
Configura PyTorch con recursos en cola
Sigue estos pasos para instalar PyTorch con recursos en cola y ejecutar una pequeña secuencia de comandos en v6e.
Instala dependencias con SSH para acceder a las VMs
Para Multislice, agrega --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Mejora el rendimiento de los modelos con asignaciones grandes y frecuentes
En el caso de los modelos que tienen asignaciones frecuentes y de tamaño, observamos que el uso de tcmalloc
mejora el rendimiento de manera significativa en comparación con la implementación predeterminada de malloc
, por lo que el malloc
predeterminado que se usa en la VM de TPU es tcmalloc
. Sin embargo, según tu carga de trabajo (por ejemplo, con DLRM, que tiene asignaciones muy grandes para sus tablas de incorporación), tcmalloc
puede causar una ralentización, en cuyo caso puedes intentar restablecer la siguiente variable con el malloc
predeterminado:
unset LD_PRELOAD
Usa una secuencia de comandos de Python para realizar un cálculo en la VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Esto genera un resultado similar al que se muestra a continuación:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configuración para TensorFlow
Para la versión preliminar pública v6e, solo se admite la versión del entorno de ejecución tf-nightly.
Para restablecer tpu-runtime
con la versión compatible con TensorFlow v6e, ejecuta los siguientes comandos:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Usa SSH para acceder a worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Instala TensorFlow en worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exporta la variable de entorno TPU_NAME
:
export TPU_NAME=v6e-16
Puedes ejecutar la siguiente secuencia de comandos de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran se generaron con un fragmento v6e-16):
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
El resultado es similar a este:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e con SkyPilot
Puedes usar TPU v6e con SkyPilot. Sigue estos pasos para agregar información de ubicación o precios relacionada con la v6e a SkyPilot.
Agrega lo siguiente al final de
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Especifica los siguientes recursos en un archivo YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Inicia un clúster con TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Conéctate a la TPU v6e con SSH:
ssh tpu_v6
Instructivos de inferencia
En los siguientes instructivos, se muestra cómo ejecutar la inferencia en TPU v6e:
Ejemplos de entrenamiento
En las siguientes secciones, se proporcionan ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en TPU v6e.
Entrenamiento de MaxText y MaxDiffusion en la VM de Cloud TPU v6e
En las siguientes secciones, se describe el ciclo de vida de entrenamiento de los modelos MaxText y MaxDiffusion.
En general, los pasos de alto nivel son los siguientes:
- Compila la imagen base de la carga de trabajo.
- Ejecuta tu carga de trabajo con XPK.
- Compila el comando de entrenamiento para la carga de trabajo.
- Implementa la carga de trabajo.
- Sigue la carga de trabajo y consulta las métricas.
- Borra la carga de trabajo de XPK si no la necesitas.
- Borra el clúster de XPK cuando ya no lo necesites.
Compila la imagen base
Instala MaxText o MaxDiffusion y compila la imagen de Docker:
Clona el repositorio que deseas usar y cambia al directorio del repositorio:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configura Docker para usar Google Cloud CLI:
gcloud auth configure-docker
Compila la imagen de Docker con el siguiente comando o con la pila estable de JAX. Para obtener más información sobre JAX Stable Stack, consulta Cómo compilar una imagen de Docker con JAX Stable Stack.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Compila una imagen de Docker con JAX Stable Stack
Puedes compilar las imágenes de Docker de MaxText y MaxDiffusion con la imagen base de la pila estable de JAX.
JAX Stable Stack proporciona un entorno coherente para MaxText y MaxDiffusion, ya que agrupa JAX con paquetes principales, como orbax
, flax
y optax
, junto con una libtpu.so bien calificada que impulsa las utilidades del programa de TPU y otras herramientas esenciales. Estas bibliotecas se prueban para garantizar la compatibilidad y proporcionar una base estable para compilar y ejecutar MaxText y MaxDiffusion. Esto elimina los posibles conflictos debido a versiones de paquetes incompatibles.
La pila estable de JAX incluye una libtpu.so completamente publicada y calificada, la biblioteca principal que impulsa la compilación, la ejecución y la configuración de la red de ICI del programa de TPU. La versión de libtpu reemplaza la compilación nocturna que usaba JAX anteriormente y garantiza una funcionalidad coherente de los cálculos de XLA en TPU con pruebas de calificación a nivel de PJRT en IR de HLO/StableHLO.
Para compilar la imagen de Docker de MaxText y MaxDiffusion con la pila estable de JAX, cuando
ejecutes la secuencia de comandos docker_build_dependency_image.sh
, establece la variable MODE
en stable_stack
y la variable BASEIMAGE
en la imagen base que deseas
usar.
En el siguiente ejemplo, se especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
como la imagen base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1
Para obtener una lista de las imágenes base de JAX Stable Stack disponibles, consulta Imágenes de JAX Stable Stack en Artifact Registry.
Ejecuta tu carga de trabajo con XPK
Establece las siguientes variables de entorno si no usas los valores predeterminados establecidos por MaxText o MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Compila la secuencia de comandos de tu modelo. Esta secuencia de comandos se copiará como un comando de entrenamiento en un paso posterior.
Aún no ejecutes la secuencia de comandos del modelo.
MaxText
MaxText es un LLM de código abierto, altamente escalable y de alto rendimiento, escrito en Python y JAX puros, y se orienta a Google Cloud TPU y GPUs para el entrenamiento y la inferencia.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma es una familia de LLM con pesos abiertos que desarrolló Google DeepMind, basada en la investigación y la tecnología de Gemini.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI, que utiliza una arquitectura dispersa de mezcla de expertos (MoE).
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama es una familia de LLM de ponderación abierta que desarrolló Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las GPUs y las Cloud TPU. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.
Debes instalar una rama específica de Git para ejecutar MaxDiffusion, como se muestra en el siguiente comando
git checkout
.git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Secuencia de comandos de entrenamiento:
cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \ python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash run_name=sdxl-ddp-v6e
Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca
--base-docker-image
para usar la imagen base de MaxText o especificar la marca--docker-image
y la imagen que deseas usar.Opcional: Puedes habilitar el registro de depuración si incluyes la marca
--enable-debug-logs
. Para obtener más información, consulta Cómo depurar JAX en MaxText.Opcional: Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard si incluyes la marca
--use-vertex-tensorboard
. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=$ACCELERATOR_TYPE \ --num-slices=$NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command $YOUR-MODEL-SCRIPT
Exporta las siguientes variables:
export ClUSTER_NAME=CLUSTER_NAME: Es el nombre de tu clúster de XPK. export ACCELERATOR_TYPEACCELERATOR_TYPE: Es la versión y el tamaño de tu TPU. Por ejemplo,
v6e-256
. export NUM_SLICES=NUM_SLICES: Es la cantidad de porciones de TPU. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento.El resultado incluye un vínculo para seguir tu carga de trabajo, similar al siguiente:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.
Cómo depurar JAX en MaxText
Usa comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo.
- Lista de cargas de trabajo de XPK
- Inspector de XPK
- Habilita el registro detallado en tus registros de cargas de trabajo con la marca
--enable-debug-logs
cuando crees la carga de trabajo de XPK.
Supervisa JAX en MaxText con Vertex AI
Consulta los datos escalares y de perfil a través de TensorBoard administrado de Vertex AI.
- Aumenta las solicitudes de administración de recursos (CRUD) de la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.
Instala dependencias como
cloud-accelerator-diagnostics
para Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea tu clúster de XPK con la marca
--create-vertex-tensorboard
, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca
--use-vertex-tensorboard
y la marca opcional--experiment-name
. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.
Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
También puedes encontrar el vínculo de Vertex AI TensorBoard en la console de Google Cloud. Ve a Experimentos de Vertex AI en la consola de Google Cloud. Selecciona la región adecuada en el menú desplegable.
El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}
.
Borra cargas de trabajo de XPK
Usa el comando xpk workload delete
para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos que están detenidos en la cola.
Borra el clúster de XPK
Usa el comando xpk cluster delete
para borrar un clúster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone $ZONE --project $PROJECT_ID
Entrenamiento de Llama y PyTorch/XLA en la VM de Cloud TPU v6e
En este instructivo, se describe cómo entrenar modelos de Llama con PyTorch/XLA en TPU v6e con el conjunto de datos de WikiText.
Obtén acceso a Hugging Face y al modelo Llama 3
Necesitas un token de acceso de usuario de Hugging Face para ejecutar este instructivo. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre los tokens de acceso de usuario.
También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.
Crea una VM de TPU
Crea una TPU v6e con 8 chips para ejecutar el instructivo.
Configure las variables de entorno:
export ACCELERATOR_TYPE=v6e-8 export VERSION=v2-alpha-tpuv6e export TPU_NAME=$USER-$ACCELERATOR_TYPE export PROJECT=YOUR_PROJECT export ZONE=YOUR_ZONE
Crea una VM de TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \ --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT
Instalación
Instala la división pytorch-tpu/transformers
de Hugging Face Transformers y las dependencias. Este instructivo se probó con las siguientes versiones de dependencias que se usan en este ejemplo:
torch
: Compatible con 2.5.0torch_xla[tpu]
: Compatible con 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \ --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Configura los parámetros de configuración del modelo
El comando de entrenamiento en la siguiente sección, Run the model, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando se entrena con modelos más pequeños, puede ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo dividir tensores en varios dispositivos en PyTorch/XLA, consulta la Guía del usuario de SPMD de PyTorch/XLA.
Crea el archivo de configuración del parámetro del modelo. La siguiente es la configuración del parámetro del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.
cat > llama-config.json <
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF Crea el archivo de configuración de FSDP:
cat > fsdp-config.json <
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF Para obtener más información sobre FSDP, consulta FSDPv2.
Sube los archivos de configuración a tus VMs de TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
Ejecuta el modelo
Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py
para entrenar el modelo Llama 3 8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda aproximadamente 10 minutos en ejecutarse en una TPU v6e-8.
Accede a Hugging Face en tu TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Ejecuta el entrenamiento de modelos:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Soluciona problemas de PyTorch/XLA
Si configuras las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación que especifique la variable PROFILE_LOGDIR
. Puedes extraer el archivo xplane.pb
almacenado en esta ubicación y usar tensorboard
para ver los perfiles en tu navegador con las instrucciones de TensorBoard. Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que tiene sugerencias para depurar, generar perfiles y optimizar tu modelo.
Entrenamiento de DLRM DCN v2 en v6e
En este instructivo, se muestra cómo entrenar el modelo DLRM DCN v2 en TPU v6e. Debes aprovisionar una TPU v6e con 64, 128 o 256 chips.
Si ejecutas en varios hosts, ejecuta el siguiente comando para restablecer tpu-runtime
con la versión correcta de TensorFlow: Si ejecutas la app en un solo host, no necesitas ejecutar los siguientes dos comandos.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Establece una conexión SSH a worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Establece el nombre de la TPU
export TPU_NAME=${TPU_NAME}
Ejecuta DLRM v2
pip install --user setuptools==65.5.0
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Ejecuta script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Las siguientes marcas son necesarias para ejecutar cargas de trabajo de recomendación (DCN de DLRM):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Resultados de comparativas
En la siguiente sección, se incluyen los resultados de las comparativas de DLRM DCN v2 y MaxDiffusion en v6e.
DLRM DCN v2
La secuencia de comandos de entrenamiento de DLRM DCN v2 se ejecutó a diferentes escalas. Consulta las tasas de transferencia en la siguiente tabla.
v6e-64 | v6e-128 | v6e-256 | |
Pasos de entrenamiento | 7000 | 7000 | 7000 |
Tamaño del lote global | 131072 | 262144 | 524288 |
Capacidad de procesamiento (ejemplos/s) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Ejecutamos la secuencia de comandos de entrenamiento para MaxDiffusion en un v6e-4, un v6e-16 y un 2xv6e-16. Consulta las tasas de transferencia en la siguiente tabla.
v6e-4 | v6e-16 | Dos v6e-16 | |
Pasos de entrenamiento | 0.069 | 0.073 | 0.13 |
Tamaño del lote global | 8 | 32 | 64 |
Capacidad de procesamiento (ejemplos/s) | 115.9 | 438.4 | 492.3 |
Programación de la recopilación
Trillium (v6e) incluye una nueva función llamada "programación de colecciones". Esta función ofrece una forma de administrar varias porciones de TPU que ejecutan una carga de trabajo de inferencia de host único en GKE y la API de Cloud TPU. Agrupar estas porciones en una colección facilita el ajuste de la cantidad de réplicas para que coincidan con la demanda. Las actualizaciones de software se controlan con cuidado para garantizar que una parte de las divisiones dentro de la colección siempre esté disponible para controlar el tráfico entrante.
Consulta la documentación de GKE para obtener más información sobre el uso de la programación de colecciones con GKE.
La función de programación de colecciones solo se aplica a la versión 6e.
Usa la programación de colecciones desde la API de Cloud TPU
Una colección de un solo host en la API de Cloud TPU es un recurso en cola en el que se establece una marca especial (--workload-type = availability-optimized
) para indicarle a la infraestructura subyacente que se debe usar para entregar cargas de trabajo.
El siguiente comando aprovisiona una colección de host único con la API de Cloud TPU:
gcloud alpha compute tpus queued-resources create my-collection \ --project=$PROJECT_ID \ --zone=${ZONE} \ --accelerator-type $ACCELERATOR_TYPE \ --node-count ${NODE_COUNT} \ --workload-type=availability-optimized
Supervisa y crea perfiles
Cloud TPU v6e admite la supervisión y la generación de perfiles con los mismos métodos que las generaciones anteriores de Cloud TPU. Para obtener más información sobre la supervisión, consulta Supervisa VMs de TPU.