Entrena un modelo con la TPU v6e
En este documento, se explica cómo entrenar modelos en Cloud TPU v6e (también llamada Trillium), y se abarca la configuración del entorno, la optimización del rendimiento y ejemplos prácticos de entrenamiento con JAX y PyTorch/XLA.
La TPU v6e, también llamada Trillium, es la 6ª generación de TPUs de Google. En todas las plataformas técnicas, como la API y los registros, y en todo este documento, se hará referencia a Trillium como v6e. Con 256 chips por Pod, la arquitectura de la TPU v6e comparte muchas similitudes con la v5e. La TPU v6e está optimizada para el entrenamiento, el ajuste y la publicación de modelos de transformadores, de texto a imagen y de redes neuronales convolucionales (CNN). Para obtener más información sobre la arquitectura y las configuraciones del sistema de la TPU v6e, consulta TPU v6e.
Para obtener información sobre cómo ejecutar la inferencia en Cloud TPU v6e, consulta los siguientes instructivos:
- Inferencia de JetStream MaxText en v6e
- Inferencia de JetStream PyTorch en v6e
- Inferencia de MaxDiffusion en v6e
- Inferencia de vLLM en v6e
- Realiza la inferencia multihost con Pathways
Antes de comenzar
Antes de comenzar, debes hacer lo siguiente:
- Crea una Google Cloud cuenta y un proyecto con la facturación habilitada
- Instala los componentes alfa de Google Cloud CLI
- Habilita la API de Cloud TPU
- Crea un agente de servicio de Cloud TPU
- Crea una cuenta de servicio de Cloud TPU y otorga permisos
Para obtener más información, consulta Configura el entorno de Cloud TPU.
Verifica la cuota y los permisos
Verifica que tu proyecto tenga las siguientes cuotas:
- Cuota de TPU v6e interrumpible o bajo demanda
- Cuota de direcciones IP
Cuota para Hyperdisk Balanced y para cualquier otro tipo de disco que desees usar
Si usas GKE con XPK, necesitas permisos adicionales en la consola de Google Cloud . Para obtener más información, consulta Permisos necesarios en la consola deGoogle Cloud .
Aprovisiona TPU
Puedes aprovisionar y administrar las TPU v6e con los siguientes métodos:
- GKE: Puedes usar GKE para aprovisionar y administrar TPUs como un grupo de aceleradores para tus cargas de trabajo de aprendizaje automático alojadas en contenedores. Para obtener más información, consulta Acerca de las TPU en GKE.
- GKE y XPK: XPK es una herramienta de línea de comandos que simplifica la creación de clústeres y la ejecución de cargas de trabajo en GKE. Está diseñado para que los profesionales del AA aprovisionen TPU y ejecuten trabajos de entrenamiento sin necesidad de tener un profundo conocimiento de Kubernetes. Para obtener más información, consulta el repositorio de GitHub de XPK.
- Recursos en cola de Cloud TPU: Los recursos en cola te permiten solicitar capacidad de TPU que se aprovisiona cuando está disponible. Es ideal para trabajos por lotes y cargas de trabajo tolerantes a errores que pueden esperar en una cola. Puedes especificar un período para tu solicitud. Para obtener más información, consulta Administra recursos en cola.
Aprovisiona Cloud TPU v6e con GKE y XPK
Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar Cloud TPUs y entrenar o entregar modelos. Consulta Planifica el uso de Cloud TPUs en GKE para obtener información sobre cómo planificar tus configuraciones de Cloud TPU en clústeres de GKE. En las siguientes secciones, se proporcionan comandos para crear un clúster de XPK con compatibilidad para una sola NIC y para varias NIC.
Crea un clúster de XPK con compatibilidad para una sola NIC
export CLUSTER_NAME=xpk-cluster-name export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME=${CLUSTER_NAME}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \ --mtu=8896 \ --project=${PROJECT_ID} \ --subnet-mode=auto \ --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \ --network=${NETWORK_NAME} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --create-vertex-tensorboard
Descripciones de las marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Es el nombre asignado por el usuario para el clúster de XPK. |
ID DEL PROYECTO | Nombre del proyectoGoogle Cloud . Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Configura tu proyecto de Google Cloud . |
ZONA | Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | Cantidad de segmentos que deseas crear |
CLUSTER_ARGUMENTS | La red y la subred que se usarán.
Por ejemplo: |
NUM_SLICES | Es la cantidad de segmentos que se crearán. |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Crea un clúster de XPK con compatibilidad para varias NIC
export CLUSTER_NAME=xpk-cluster-name export REGION=your-region export ZONE=us-east1-d export PROJECT_ID=your-project-id export TPU_TYPE=v6e-256 export NUM_SLICES=2 export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE} export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE} export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE} export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE} export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE} export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \ --network=${NETWORK_NAME_1} \ --range=10.11.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_1} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_1} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \ --mtu=8896 \ --bgp-routing-mode=regional \ --subnet-mode=custom \ --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 \ --region=${REGION} \ --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} \ --allow tcp,icmp,udp \ --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \ --cluster=${CLUSTER_NAME} \ --cluster-cpu-machine-type=e2-standard-8 \ --num-slices=${NUM_SLICES} \ --tpu-type=${TPU_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --on-demand \ --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \ --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \ --create-vertex-tensorboard
Descripciones de las marcas de comandos
Variable | Descripción |
CLUSTER_NAME | Es el nombre asignado por el usuario para el clúster de XPK. |
ID DEL PROYECTO | Nombre del proyectoGoogle Cloud . Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Configura tu proyecto de Google Cloud . |
ZONA | Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles. |
TPU_TYPE | Consulta Tipos de aceleradores. |
NUM_SLICES | Cantidad de segmentos que deseas crear |
CLUSTER_ARGUMENTS | La red y la subred que se usarán.
Por ejemplo: |
NODE_POOL_ARGUMENTS | Es la red de nodos adicional que se usará.
Por ejemplo: |
NUM_SLICES | Cantidad de segmentos que se crearán (solo es necesario para Multislice). |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Configura JAX o PyTorch
En los siguientes recursos, se muestra cómo configurar JAX o PyTorch en tu Cloud TPU, según el método de aprovisionamiento y administración que uses:
- GKE Autopilot: Prepara tu aplicación para TPU
- GKE Standard: Prepara tus cargas de trabajo
- GKE y XPK: README de XPK
- Cloud TPU de un solo host con JAX: Ejecuta un cálculo en una VM de Cloud TPU con JAX
- Cloud TPU de varios hosts con JAX: Ejecuta código JAX en porciones de TPU
- Cloud TPU de host único con PyTorch: Ejecuta un cálculo en una VM de Cloud TPU con PyTorch
- Cloud TPU de varios hosts con PyTorch: Ejecuta código de PyTorch en segmentos de TPU
Para configurar y ejecutar XPK con MaxText, consulta Cómo ejecutar MaxText a gran escala con XPK .
Optimiza el rendimiento de la red
En esta sección, se describe cómo optimizar el rendimiento de tu red configurando la unidad de transmisión máxima (MTU), usando varias NIC para entornos de Multislice y mejorando la configuración de TCP.
Configura la MTU
Para obtener el mejor rendimiento de la red, usa una red con una MTU (unidad de transmisión máxima) de 8,896.
De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que genera un rendimiento de red subóptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.
Para obtener más información sobre cómo cambiar el parámetro de configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.
En el siguiente ejemplo,se crea una red con un MTU de 8, 896 y una regla de firewall correspondiente que permite el tráfico de TCP, ICMP y UDP dentro de la red.
export RESOURCE_NAME=your-resource-name export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \ --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \ --allow tcp,icmp,udp --project=${PROJECT_ID}
Reemplaza your-resource-name por un nombre base para la red y el firewall.
Usa la opción de varias NIC para Multislice
Si usas un entorno de Multislice, configura las siguientes variables de entorno, que son obligatorias para una subred secundaria:
export NETWORK_NAME_2=${RESOURCE_NAME} export SUBNET_NAME_2=${RESOURCE_NAME} export FIREWALL_RULE_NAME=${RESOURCE_NAME} export ROUTER_NAME=${RESOURCE_NAME}-network-2 export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2 export REGION=your-region
Usa los siguientes comandos para crear un enrutamiento de IP personalizado para la red y la subred.
Crea la red secundaria.
gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \ --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
Crea una subred para la red secundaria.
gcloud compute networks subnets create ${SUBNET_NAME_2} \ --network=${NETWORK_NAME_2} \ --range=10.10.0.0/18 --region=${REGION} \ --project=${PROJECT_ID}
Crea una regla de firewall para permitir el tráfico dentro de la subred nueva.
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \ --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \ --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
Crea un Cloud Router para la red secundaria.
gcloud compute routers create ${ROUTER_NAME} \ --project=${PROJECT_ID} \ --network=${NETWORK_NAME_2} \ --region=${REGION}
Crea una configuración de NAT para el Cloud Router.
gcloud compute routers nats create ${NAT_CONFIG} \ --router=${ROUTER_NAME} \ --region=${REGION} \ --auto-allocate-nat-external-ips \ --nat-all-subnet-ip-ranges \ --project=${PROJECT_ID} \ --enable-logging
Después de crear un segmento de varias redes, puedes validar que se usen ambas tarjetas de interfaz de red (NIC) configurando un clúster de XPK y agregando la marca --command ifconfig
al comando de creación de la carga de trabajo de XPK.
Usa el siguiente comando
workload create
para mostrar el resultado del comandoifconfig
en los registros de la consola de Google Cloud y verifica que eth0 y eth1 tengan el MTU establecido en 8,896.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command "ifconfig"
Si deseas habilitar los registros de depuración o usar Vertex AI TensorBoard, agrega los siguientes argumentos opcionales al comando:
--enable-debug-logs \ --use-vertex-tensorboard
Verifica que tanto eth0 como eth1 tengan la MTU establecida en 8,896. Para ello, revisa el resultado de la carga de trabajo de XPK en los registros de la consola de Google Cloud .
Mejora la configuración de TCP
Si aprovisionaste tus Cloud TPU con recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "${PROJECT_ID}" \ --zone "${ZONE}" \ --node=all \ --worker=all \ --command=' sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
Optimiza el rendimiento de la asignación de memoria
La biblioteca tcmalloc
se usa de forma predeterminada en las VMs de Cloud TPU para mejorar el rendimiento de los modelos con asignaciones de memoria frecuentes y de gran tamaño. Se configura a través de la variable de entorno LD_PRELOAD
.
Sin embargo, para algunas cargas de trabajo (por ejemplo, DLRM con asignaciones de tablas de incorporación muy grandes), tcmalloc
puede causar una ralentización. En esos casos, puedes volver a la función malloc
estándar anulando la variable LD_PRELOAD
en tu sesión de shell antes de ejecutar la secuencia de comandos de entrenamiento:
unset LD_PRELOAD
Usa SkyPilot
Puedes usar Cloud TPU v6e con SkyPilot. SkyPilot es un framework de código abierto que simplifica el proceso de ejecución, administración y escalamiento de cargas de trabajo de IA. Puedes agregar información sobre la ubicación y los precios relacionados con v6e a SkyPilot. Para obtener más información, consulta el ejemplo de TPU v6e de SkyPilot.
Ejemplos de entrenamiento
En las siguientes secciones, se proporcionan ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en Cloud TPU v6e.
Estos ejemplos se probaron con las siguientes versiones de software:
- Python
3.10
o una versión posterior - Versiones de software nocturnas:
- JAX nocturno
0.4.32.dev20240912
- LibTPU nocturna
0.1.dev20240912+nightly
- JAX nocturno
- Versiones de software estables:
- JAX y JAX Lib de la versión 0.4.37
Entrena MaxText y MaxDiffusion en Cloud TPU v6e
En las siguientes secciones, se abarca el ciclo de vida del entrenamiento de los modelos MaxText y MaxDiffusion.
En general, los pasos de alto nivel son los siguientes:
- Compila la imagen base de la carga de trabajo.
- Ejecuta tu carga de trabajo con XPK.
- Compila el comando de entrenamiento para la carga de trabajo.
- Implementa la carga de trabajo.
- Sigue la carga de trabajo y consulta las métricas.
- Borra la carga de trabajo de XPK si no es necesaria.
- Borra el clúster de XPK cuando ya no lo necesites.
Compila la imagen base
Instala MaxText o MaxDiffusion y compila la imagen de Docker:
Clona el repositorio que deseas usar y cambia al directorio del repositorio:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
Configura Docker para usar Google Cloud CLI:
gcloud auth configure-docker
Compila la imagen de Docker con el siguiente comando o con una imagen de IA de JAX. Para obtener más información sobre las imágenes de IA de JAX, consulta Imágenes de IA de JAX.
MaxText:
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
MaxDiffusion:
bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
Configura tu ID del proyecto en la configuración activa de gcloud CLI:
gcloud config set project ${PROJECT_ID}
Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen.
Establece la variable de entorno
CLOUD_IMAGE_NAME
:export CLOUD_IMAGE_NAME=${USER}_runner
Sube la imagen:
bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
Ejecuta tu carga de trabajo con XPK
Establece las siguientes variables de entorno si no usas los valores predeterminados establecidos por MaxText o MaxDiffusion:
export BASE_OUTPUT_DIR=gs://YOUR_BUCKET export PER_DEVICE_BATCH_SIZE=2 export NUM_STEPS=30 export MAX_TARGET_LENGTH=8192
Crea tu secuencia de comandos del modelo. Esta secuencia de comandos se copiará como un comando de entrenamiento en un paso posterior.
Aún no ejecutes la secuencia de comandos del modelo.
MaxText
MaxText es un LLM de código abierto de alto rendimiento y altamente escalable escrito en Python y JAX puros, y orientado a Google Cloud TPUs y GPUs para el entrenamiento y la inferencia.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} # attention='dot_product'"
Gemma2
Gemma es una familia de LLMs de código abierto desarrollados por Google DeepMind, basados en la investigación y la tecnología de Gemini.
python3 -m MaxText.train MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI que utiliza una arquitectura de mezcla de expertos (MoE) dispersa.
python3 -m MaxText.train MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama es una familia de LLMs de código abierto desarrollados por Meta.
Para ver un ejemplo de cómo ejecutar Llama3 en PyTorch, consulta los modelos de torch_xla en el repositorio de torchprime.
MaxDiffusion
MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las Cloud TPU y las GPU. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.
Debes instalar una rama de Git específica para ejecutar MaxDiffusion, como se muestra en la siguiente secuencia de comandos de entrenamiento.
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && pip install -r requirements.txt && pip install . && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR} && python src/maxdiffusion/train_sdxl.py \ src/maxdiffusion/configs/base_xl.yml \ revision=refs/pr/95 \ activations_dtype=bfloat16 \ weights_dtype=bfloat16 \ resolution=1024 \ per_device_batch_size=1 \ output_dir=${OUT_DIR} \ jax_cache_dir=${OUT_DIR}/cache_dir/ \ max_train_steps=200 \ attention=flash \ run_name=sdxl-ddp-v6e
Exporta las siguientes variables:
export CLUSTER_NAME=CLUSTER_NAME export ACCELERATOR_TYPE=ACCELERATOR_TYPE export NUM_SLICES=NUM_SLICES export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT
Descripciones de las variables de entorno
Variable Descripción CLUSTER_NAME
Es el nombre de tu clúster de XPK. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. NUM_SLICES
Es la cantidad de porciones de TPU. YOUR_MODEL_SCRIPT
Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento. Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca
--base-docker-image
para usar la imagen base de MaxText o la marca--docker-image
y la imagen que deseas usar.Puedes agregar las siguientes marcas opcionales:
- Puedes habilitar el registro de depuración incluyendo la marca
--enable-debug-logs
. Para obtener más información, consulta Cómo depurar JAX en MaxText. - Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard incluyendo la marca
--use-vertex-tensorboard
. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \ --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone=${ZONE} \ --project=${PROJECT_ID} \ --command="${YOUR_MODEL_SCRIPT}"
El resultado incluye un vínculo para seguir tu carga de trabajo. Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.
- Puedes habilitar el registro de depuración incluyendo la marca
Cómo depurar JAX en MaxText
Usa comandos de XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:
- Lista de cargas de trabajo de XPK
- Inspector de XPK
- Habilita el registro detallado en los registros de tu carga de trabajo con la marca
--enable-debug-logs
cuando crees la carga de trabajo de XPK.
Supervisa JAX en MaxText con Vertex AI
Para usar TensorBoard, tu cuenta de usuario Google Cloud debe tener el rol de aiplatform.user
. Ejecuta el siguiente comando para otorgar este rol:
gcloud projects add-iam-policy-binding your-project-id \ --member='user:your-email' \ --role='roles/aiplatform.user'
Visualiza datos de perfil y escalares a través de TensorBoard administrado por Vertex AI.
Aumenta las solicitudes de administración de recursos (CRUD) para la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.
Instala dependencias como
cloud-accelerator-diagnostics
para Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea tu clúster de XPK con la marca
--create-vertex-tensorboard
, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca
--use-vertex-tensorboard
y la marca opcional--experiment-name
. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.
Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
También puedes encontrar el vínculo de TensorBoard de Vertex AI en la consola de Google Cloud . Ve a Vertex AI Experiments en la consola de Google Cloud . Selecciona la región adecuada en el menú desplegable.
El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}
.
Borra tu carga de trabajo de XPK
Usa el comando xpk workload delete
para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos atascados en la cola.
Borra tu clúster de XPK
Usa el comando xpk cluster delete
para borrar el clúster:
python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \ --zone=${ZONE} --project=${PROJECT_ID}
Resultados de las comparativas de MaxDiffusion
Ejecutamos la secuencia de comandos de entrenamiento de MaxDiffusion en una v6e-4, una v6e-16 y dos v6e-16. En la siguiente tabla, se muestran las capacidades de procesamiento medidas.
v6e-4 | v6e-16 | Dos v6e-16 | |
---|---|---|---|
Pasos de entrenamiento | 0.069 | 0.073 | 0.13 |
Tamaño del lote global | 8 | 32 | 64 |
Capacidad de procesamiento (ejemplos/s) | 115.9 | 438.4 | 492.3 |
Entrena modelos de Llama con PyTorch/XLA en Cloud TPU v6e
En esta sección, se describe cómo entrenar modelos de Llama con PyTorch/XLA en Cloud TPU v6e usando el conjunto de datos de WikiText.
Obtén acceso a Hugging Face y al modelo de Llama 3
Para este ejemplo, necesitas un token de acceso de usuario de Hugging Face. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre tokens de acceso de usuario.
También necesitas permiso para acceder al modelo Llama-3-8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.
Crea una VM de Cloud TPU
Crea una Cloud TPU v6e con 8 chips para este ejemplo.
Configure las variables de entorno:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=us-east1-d export ACCELERATOR_TYPE=v6e-8 export RUNTIME_VERSION=v2-alpha-tpuv6e
Descripciones de las variables de entorno
Variable Descripción PROJECT_ID
El ID de tu proyecto Google Cloud . Usa un proyecto existente o crea uno nuevo. TPU_NAME
Nombre de la TPU. ZONE
Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU. ACCELERATOR_TYPE
El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU. RUNTIME_VERSION
Versión de software de Cloud TPU. Crea una VM de Cloud TPU:
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \ --accelerator-type=${ACCELERATOR_TYPE} \ --zone=${ZONE} \ --project=${PROJECT_ID}
Instalación
Instala la bifurcación de pytorch-tpu/transformers
de Transformers de Hugging Face y las dependencias. Este ejemplo se probó con las siguientes versiones de dependencias:
torch
: Compatible con la versión 2.5.0torch_xla[tpu]
: Compatible con la versión 2.5.0jax
: 0.4.33jaxlib
: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers sudo pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
Cómo configurar archivos de configuración del modelo
El comando de entrenamiento de la siguiente sección, Ejecuta el modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de Fully Sharded Data Parallel (FSDP). El sharding de FSDP te permite usar un tamaño de lote más grande durante el entrenamiento, ya que fragmenta los pesos del modelo en varias TPU. Cuando se entrena con modelos más pequeños, podría ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Si deseas obtener más información para fragmentar tensores en dispositivos con PyTorch/XLA, consulta la guía del usuario de SPMD de PyTorch/XLA.
Crea el archivo de configuración de parámetros del modelo. A continuación, se muestra la configuración de los parámetros del modelo para Llama-3-8B. Para otros modelos, busca el archivo de configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama-2-7B.
cat > llama-config.json << EOF { "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 } EOF
Crea el archivo de configuración de FSDP:
cat > fsdp-config.json << EOF { "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true } EOF
Para obtener más información sobre FSDP, consulta Fully Sharded Data Parallel using SPMD .
Sube los archivos de configuración a tus VMs de Cloud TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \ --worker=all \ --project=${PROJECT_ID} \ --zone=${ZONE}
Ejecuta el modelo
Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py
para entrenar el modelo Llama-3-8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda alrededor de 10 minutos en ejecutarse en una Cloud TPU v6e-8.
Accede a Hugging Face en tu Cloud TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' pip3 install "huggingface_hub[cli]" huggingface-cli login --token HUGGING_FACE_TOKEN'
Ejecuta el entrenamiento de modelos:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --project=${PROJECT_ID} \ --zone ${ZONE} \ --worker=all \ --command=' export PJRT_DEVICE=TPU export XLA_USE_SPMD=1 export ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: export XLA_IR_DEBUG=1 export XLA_HLO_DEBUG=1 export PROFILE_EPOCH=0 export PROFILE_STEP=3 export PROFILE_DURATION_MS=100000 # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path export PROFILE_LOGDIR=PROFILE_PATH python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 16 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/llama-config.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp-config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20'
Solución de problemas de PyTorch/XLA
Si configuraste las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación especificada por la variable PROFILE_LOGDIR
. Puedes extraer el archivo xplane.pb
almacenado en esta ubicación y usar tensorboard
para ver los perfiles en tu navegador con las instrucciones de TensorBoard.
Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que incluye sugerencias para depurar, generar perfiles y optimizar tu modelo.