Entrenar un modelo con TPU v6e
En este documento se explica cómo entrenar modelos en Cloud TPU v6e (también llamada Trillium), lo que incluye la configuración del entorno, la optimización del rendimiento y ejemplos prácticos de entrenamiento con JAX y PyTorch/XLA.
La TPU v6e, también llamada Trillium, es la TPU de sexta generación de Google. En todas las plataformas técnicas, como la API y los registros, y a lo largo de este documento, Trillium se denominará v6e. Con 256 chips por Pod, la arquitectura de la versión 6e de TPU comparte muchas similitudes con la versión 5e. La TPU v6e está optimizada para el entrenamiento, el ajuste y el servicio de transformadores, modelos de texto a imagen y redes neuronales convolucionales (CNNs). Para obtener más información sobre la arquitectura y las configuraciones del sistema de TPU v6e, consulta TPU v6e.
Para obtener información sobre cómo ejecutar inferencias en la versión 6e de TPU de Cloud, consulta los siguientes tutoriales:
- Inferencia de MaxText de JetStream en la versión 6e
- Inferencia de PyTorch de JetStream en v6e
- Inferencia de MaxDiffusion en v6e
- Inferencia de vLLM en v6e
- Realizar inferencias multihost con Pathways
Antes de empezar
Antes de empezar, debes hacer lo siguiente:
- Crear una Google Cloud cuenta y un proyecto con la facturación habilitada
- Instalar los componentes alfa de Google Cloud CLI
- Habilitar la API de Cloud TPU
- Crear un agente de servicio de TPU de Cloud
- Crear una cuenta de servicio de TPU de Cloud y conceder permisos
Para obtener más información, consulta Configurar el entorno de TPU de Cloud.
Verificar la cuota y los permisos
Comprueba que tu proyecto tenga las siguientes cuotas:
- Cuota de TPU v6e no garantizadas o bajo demanda
- Cuota de direcciones IP
Cuota de Hyperdisk Balanced y de cualquier otro tipo de disco que quieras usar
Si usas GKE con XPK, necesitas permisos adicionales en la Google Cloud consola. Para obtener más información, consulta Permisos necesarios en la consolaGoogle Cloud .
Provisionar TPUs
Puedes aprovisionar y gestionar TPUs v6e con los siguientes métodos:
- GKE puedes usar GKE para aprovisionar y gestionar las TPUs como un grupo de aceleradores para tus cargas de trabajo de aprendizaje automático en contenedores. Para obtener más información, consulta el artículo Acerca de las TPUs en GKE.
- GKE y XPK: XPK es una herramienta de línea de comandos que simplifica la creación de clústeres y la ejecución de cargas de trabajo en GKE. Está diseñado para que los profesionales del aprendizaje automático puedan aprovisionar TPUs y ejecutar trabajos de entrenamiento sin necesidad de tener un gran dominio de Kubernetes. Para obtener más información, consulta el repositorio de GitHub de XPK.
- Recursos en cola de TPU de Cloud: los recursos en cola te permiten solicitar capacidad de TPU que se aprovisiona cuando está disponible. Es ideal para tareas por lotes y cargas de trabajo tolerantes a fallos que pueden esperar en una cola. Puedes especificar un periodo para tu solicitud. Para obtener más información, consulta Gestionar recursos en cola.
Provisionar TPUs de Cloud v6e con GKE y XPK
Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar Cloud TPUs y entrenar o servir modelos. Consulta Planificar las TPUs de Cloud en GKE para saber cómo planificar las configuraciones de TPUs de Cloud en clústeres de GKE. En las siguientes secciones se proporcionan comandos para crear un clúster XPK con compatibilidad con una sola NIC y con varias NIC.
Crear un clúster XPK con compatibilidad con una sola NIC
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Descripciones de marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Nombre asignado por el usuario al clúster XPK. |
PROJECT_ID | Google Cloud Nombre del proyecto. Usa un proyecto que ya tengas o crea uno. Para obtener más información, consulta el artículo Configurar un Google Cloud proyecto. |
ZONE | Consulta el documento Regiones y zonas de TPU de Cloud para ver las zonas admitidas. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | El número de porciones que quieres crear |
CLUSTER_ARGUMENTS | La red y la subred que se van a usar.
Por ejemplo: |
NUM_SLICES | Número de porciones que se van a crear. |
NETWORK_NAME | Nombre de una red secundaria que se va a usar. |
NETWORK_FW_NAME | Nombre de un cortafuegos de red secundario que se va a usar. |
Crear un clúster de XPK con compatibilidad con varias NICs
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
Descripciones de marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Nombre asignado por el usuario al clúster XPK. |
PROJECT_ID | Google Cloud Nombre del proyecto. Usa un proyecto que ya tengas o crea uno. Para obtener más información, consulta el artículo Configurar un Google Cloud proyecto. |
ZONE | Consulta el documento Regiones y zonas de TPU de Cloud para ver las zonas admitidas. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | El número de porciones que quieres crear |
CLUSTER_ARGUMENTS | La red y la subred que se van a usar.
Por ejemplo: |
NODE_POOL_ARGUMENTS | Red de nodos adicional que se va a utilizar.
Por ejemplo: |
NUM_SLICES | Número de segmentos que se van a crear (solo es necesario para Multislice). |
NETWORK_NAME | Nombre de una red secundaria que se va a usar. |
NETWORK_FW_NAME | Nombre de un cortafuegos de red secundario que se va a usar. |
Configurar JAX o PyTorch
En los siguientes recursos se muestra cómo configurar JAX o PyTorch en tu TPU de Cloud, en función del método de aprovisionamiento y gestión que utilices:
- Autopilot de GKE: prepara tu aplicación de TPU
- GKE Standard: prepara tus cargas de trabajo
- GKE y XPK: README de XPK
- TPU de Cloud de un solo host con JAX: ejecutar un cálculo en una VM de TPU de Cloud con JAX
- TPU de Cloud multihost con JAX: ejecutar código JAX en sectores de TPU
- TPU de Cloud de un solo host con PyTorch: ejecutar un cálculo en una máquina virtual de TPU de Cloud con PyTorch
- TPU de Cloud multihost con PyTorch: ejecutar código de PyTorch en slices de TPU
Para configurar y ejecutar XPK con MaxText, consulta Ejecutar MaxText a gran escala con XPK .
Optimizar el rendimiento de la red
En esta sección se describe cómo optimizar el rendimiento de su red configurando la unidad de transmisión máxima (MTU), usando varias NICs en entornos Multislice y mejorando los ajustes de TCP.
Configurar la MTU
Para obtener el mejor rendimiento de la red, utiliza una red con una MTU (unidad de transmisión máxima) de 8896.
De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1460 bytes, lo que ofrece un rendimiento de red no óptimo. Puedes definir el valor de MTU de una red VPC entre 1300 y 8896 bytes (ambos incluidos). Los tamaños de MTU personalizados habituales son 1500 bytes (Ethernet estándar) u 8896 bytes (el máximo posible). Para obtener más información, consulta Tamaños de MTU válidos para redes VPC.
Para obtener más información sobre cómo cambiar el ajuste de MTU de una red predeterminada o de una red que ya tengas, consulta Cambiar el ajuste de MTU de una red de VPC.
En el siguiente ejemplo, se crea una red con un MTU de 8896 y una regla de cortafuegos correspondiente que permite el tráfico TCP, ICMP y UDP en la red.
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
Sustituye your-resource-name por un nombre base para la red y el cortafuegos.
Usar la opción de varias NICs para Multislice
Si usas un entorno Multislice, define las siguientes variables de entorno, que son obligatorias para una subred secundaria:
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
Usa los siguientes comandos para crear un enrutamiento de IP personalizado para la red y la subred.
Crea la red secundaria.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \ --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
Crea una subred para la red secundaria.
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 --region=${REGION} \ --project=${PROJECT_ID}
Crea una regla de cortafuegos para permitir el tráfico en la nueva subred.
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \ --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
Crea un router de Cloud Router para la red secundaria.
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
Crea una configuración de NAT para el router de Cloud Router.
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
Después de crear una partición de red múltiple, puedes validar que se están usando ambas tarjetas de interfaz de red (NICs) configurando un clúster XPK y añadiendo la marca --command ifconfig
al comando de creación de carga de trabajo XPK.
Usa el siguiente comando
workload create
para mostrar el resultado del comandoifconfig
en los registros de la consola y comprueba que tanto eth0 como eth1 tienen el valor 8896 en MTU. Google Cloudpython3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
Si quieres habilitar los registros de depuración o usar Vertex AI TensorBoard, añade los siguientes argumentos opcionales al comando:
--enable-debug-logs \ --use-vertex-tensorboard
Comprueba que tanto eth0 como eth1 tengan el MTU configurado en 8896. Para ello,consulta la salida de la carga de trabajo de XPK en los registros de la consola Google Cloud .
Mejorar la configuración de TCP
Si has aprovisionado tus Cloud TPUs mediante recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
Optimizar el rendimiento de la asignación de memoria
La biblioteca tcmalloc
se usa de forma predeterminada en las VMs de TPU de Cloud para mejorar el rendimiento de los modelos con asignaciones de memoria frecuentes y de gran tamaño. Esto se configura mediante la variable de entorno LD_PRELOAD
.
Sin embargo, en algunas cargas de trabajo (por ejemplo, DLRM con asignaciones de tablas de inserciones muy grandes), tcmalloc
puede provocar una ralentización. En estos casos, puedes volver a la función malloc
estándar anulando la variable LD_PRELOAD
en tu sesión de shell antes de ejecutar la secuencia de comandos de entrenamiento:
unset LD_PRELOAD
Usar SkyPilot
Puedes usar Cloud TPU v6e con SkyPilot. SkyPilot es un framework de código abierto que simplifica el proceso de ejecutar, gestionar y escalar cargas de trabajo de IA. Puedes añadir información sobre la ubicación y los precios relacionados con v6e a SkyPilot. Para obtener más información, consulta el ejemplo de TPU v6e de SkyPilot.
Ejemplos de entrenamiento
En las siguientes secciones se ofrecen ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en Cloud TPU v6e.
Estos ejemplos se han probado con las siguientes versiones de software:
- Python
3.10
o versiones posteriores - Versiones de software nocturnas:
- JAX nocturno
0.4.32.dev20240912
- LibTPU nocturno
0.1.dev20240912+nightly
- JAX nocturno
- Versiones de software estables:
- JAX + JAX Lib de la versión 0.4.37
Entrenar MaxText y MaxDiffusion en TPU v6e de Cloud
En las siguientes secciones se describe el ciclo de vida del entrenamiento de los modelos MaxText y MaxDiffusion.
En general, los pasos que debe seguir son los siguientes:
- Crea la imagen base de la carga de trabajo.
- Ejecuta tu carga de trabajo con XPK.
- Crea el comando de entrenamiento para la carga de trabajo.
- Despliega la carga de trabajo.
- Sigue la carga de trabajo y consulta las métricas.
- Elimina la carga de trabajo XPK si no es necesaria.
- Elimina el clúster XPK cuando ya no lo necesites.
Crear imagen base
Instala MaxText o MaxDiffusion y crea la imagen de Docker:
Clona el repositorio que quieras usar y cambia al directorio del repositorio:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Configura Docker para usar la CLI de Google Cloud:
gcloud auth configure-docker
Crea la imagen de Docker con el siguiente comando o con una imagen de IA de JAX. Para obtener más información sobre las imágenes de JAX AI, consulta Imágenes de JAX AI.
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
Define el ID de tu proyecto en la configuración activa de gcloud CLI:
gcloud config set project ${PROJECT_ID}
Si vas a iniciar la carga de trabajo desde una máquina que no tiene la imagen compilada localmente, sube la imagen.
Define la variable de entorno
CLOUD_IMAGE_NAME
:export CLOUD_IMAGE_NAME=${USER}_runner
Sube la imagen:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Ejecutar una carga de trabajo con XPK
Define las siguientes variables de entorno si no usas los valores predeterminados definidos por MaxText o MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Crea la secuencia de comandos del modelo. Esta secuencia de comandos se copiará como comando de entrenamiento en un paso posterior.
No ejecutes la secuencia de comandos del modelo todavía.
MaxText
MaxText es un LLM de código abierto de alto rendimiento y alta escalabilidad escrito en Python y JAX puros, y orientado a Google Cloud TPUs y GPUs para el entrenamiento y la inferencia.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma es una familia de LLMs de código abierto desarrollada por Google DeepMind a partir de la investigación y la tecnología de Gemini.
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI que utiliza una arquitectura de Mixture-of-Experts (MoE) dispersa.
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama es una familia de LLMs de pesos abiertos desarrollada por Meta.
Para ver un ejemplo de cómo ejecutar Llama 3 en PyTorch, consulta los modelos torch_xla del repositorio torchprime.
MaxDiffusion
MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latente escritos en Python y JAX puros que se ejecutan en dispositivos XLA, como las TPUs de Cloud y las GPUs. Stable Diffusion es un modelo de conversión de texto a imagen latente que genera imágenes fotorrealistas a partir de cualquier entrada de texto.
Debes instalar una rama de Git específica para ejecutar MaxDiffusion, tal como se muestra en la siguiente secuencia de comandos de entrenamiento.
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash \ run_name=sdxl-ddp-v6e
Exporta las siguientes variables:
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Descripciones de las variables de entorno
Variable Descripción CLUSTER_NAME
El nombre de tu clúster XPK. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. NUM_SLICES
Número de porciones de TPU. YOUR_MODEL_SCRIPT
El script del modelo que se va a ejecutar como comando de entrenamiento. Ejecuta el modelo con la secuencia de comandos que has creado en el paso anterior. Debes especificar la marca
--base-docker-image
para usar la imagen base MaxText o la marca--docker-image
y la imagen que quieras usar.Puedes añadir las siguientes marcas opcionales:
- Para habilitar el registro de depuración, incluye la marca
--enable-debug-logs
. Para obtener más información, consulta Depurar JAX en MaxText. - Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard incluyendo la marca
--use-vertex-tensorboard
. Para obtener más información, consulta Monitorizar JAX en MaxText con Vertex AI.
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command="${YOUR_MODEL_SCRIPT}"
El resultado incluye un enlace para seguir tu carga de trabajo. Abre el enlace y haz clic en la pestaña Registros para monitorizar tu carga de trabajo en tiempo real.
- Para habilitar el registro de depuración, incluye la marca
Depurar JAX en MaxText
Usa comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:
- Lista de cargas de trabajo de XPK
- Inspector de XPK
- Habilita el registro detallado en los registros de tu carga de trabajo con la marca
--enable-debug-logs
cuando crees la carga de trabajo XPK.
Monitorizar JAX en MaxText con Vertex AI
Para usar TensorBoard, tu cuenta de usuario debe tener el rol aiplatform.user
. Google Cloud Ejecuta el siguiente comando para conceder este rol:
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
Consulta datos escalares y de perfil a través de TensorBoard gestionado de Vertex AI.
Aumenta las solicitudes de gestión de recursos (CRUD) de la zona que estás usando de 600 a 5000. Esto no debería suponer un problema para cargas de trabajo pequeñas que usen menos de 16 VMs.
Instala dependencias como
cloud-accelerator-diagnostics
para Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea tu clúster XPK con la marca
--create-vertex-tensorboard
, tal como se describe en el artículo Crear Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres que ya tengas.Crea tu experimento de Vertex AI al ejecutar tu carga de trabajo de XPK con la marca
--use-vertex-tensorboard
y la marca opcional--experiment-name
. Para ver la lista completa de pasos, consulta Crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.
Los registros incluyen un enlace a Vertex AI TensorBoard, similar al siguiente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
También puedes encontrar el enlace de Vertex AI TensorBoard en la Google Cloud consola. Ve a Vertex AI Experiments en la Google Cloud consola. Selecciona la región correspondiente en el menú desplegable.
El directorio de TensorBoard también se escribe en el segmento de Cloud Storage que has especificado con ${BASE_OUTPUT_DIR}
.
Eliminar tu carga de trabajo de XPK
Usa el comando xpk workload delete
para eliminar una o varias cargas de trabajo en función del prefijo o del estado de la tarea. Este comando puede ser útil si has enviado cargas de trabajo XPK que ya no necesitas ejecutar o si tienes trabajos atascados en la cola.
Eliminar el clúster XPK
Usa el comando xpk cluster delete
para eliminar el clúster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
Resultados de las pruebas de rendimiento de MaxDiffusion
Hemos ejecutado la secuencia de comandos de entrenamiento de MaxDiffusion en una v6e-4, una v6e-16 y dos v6e-16. En la tabla siguiente se muestran los rendimientos medidos.
v6e-4 | v6e-16 | Dos v6e-16 | |
---|---|---|---|
Pasos de formación | 0,069 | 0,073 | 0,13 |
Tamaño de lote global | 8 | 32 | 64 |
Rendimiento (ejemplos/s) | 115,9 | 438,4 EGP | 492,3 |
Entrenar modelos Llama con PyTorch/XLA en TPU v6e de Cloud
En esta sección se describe cómo entrenar modelos Llama con PyTorch/XLA en Cloud TPU v6e usando el conjunto de datos WikiText.
Acceder a Hugging Face y al modelo Llama 3
Para este ejemplo, necesitas un token de acceso de usuario de Hugging Face. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre tokens de acceso de usuario.
También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.
Crear una VM de TPU de Cloud
Crea una TPU de Cloud v6e con 8 chips para este ejemplo.
Configura las variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east1-d export ACCELERATOR_TYPE=v6e-8 export RUNTIME_VERSION=v2-alpha-tpuv6e
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno. TPU_NAME
El nombre de la TPU. ZONE
La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
La versión de software de la TPU de Cloud. Crea una VM de TPU de Cloud:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Instalación
Instala la bifurcación pytorch-tpu/transformers
de las transformaciones de Hugging Face y las dependencias. Este ejemplo se ha probado con las siguientes versiones de dependencia:
torch
: compatible con la versión 2.5.0torch_xla[tpu]
: compatible con la versión 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Configurar archivos de configuración de modelos
El comando de entrenamiento de la siguiente sección, Ejecutar el modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de Fully Sharded Data Parallel (FSDP). El particionado de FSDP te permite usar un tamaño de lote mayor durante el entrenamiento al particionar las ponderaciones de tu modelo en varias TPUs. Cuando se entrena con modelos más pequeños, puede ser suficiente con usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo fragmentar tensores entre dispositivos en PyTorch/XLA, consulta la guía de usuario de SPMD de PyTorch/XLA.
Crea el archivo de configuración de parámetros del modelo. A continuación, se muestra la configuración de los parámetros del modelo Llama-3-8B. En el caso de otros modelos, busca el archivo de configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama-2-7B.
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Crea el archivo de configuración de FSDP:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Para obtener más información sobre FSDP, consulta Paralelismo de datos totalmente fragmentado con SPMD .
Sube los archivos de configuración a tus VMs de Cloud TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Ejecutar el modelo
Con los archivos de configuración que has creado en la sección anterior, ejecuta la secuencia de comandos run_clm.py
para entrenar el modelo Llama-3-8B con el conjunto de datos WikiText. El script de entrenamiento tarda aproximadamente 10 minutos en ejecutarse en una TPU de Cloud v6e-8.
Inicia sesión en Hugging Face en tu TPU de Cloud con el siguiente comando:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Ejecuta el entrenamiento del modelo:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Solución de problemas de PyTorch/XLA
Si has definido las variables opcionales para depurar en la sección anterior, el perfil del modelo se almacenará en la ubicación especificada por la variable PROFILE_LOGDIR
. Puedes extraer el archivo xplane.pb
almacenado en esta ubicación y usar tensorboard
para ver los perfiles en tu navegador siguiendo las instrucciones de TensorBoard.
Si PyTorch/XLA no funciona como esperabas, consulta la guía de solución de problemas, que incluye sugerencias para depurar, perfilar y optimizar tu modelo.