Entrenar un modelo con TPU v6e

En este documento se explica cómo entrenar modelos en Cloud TPU v6e (también llamada Trillium), lo que incluye la configuración del entorno, la optimización del rendimiento y ejemplos prácticos de entrenamiento con JAX y PyTorch/XLA.

La TPU v6e, también llamada Trillium, es la TPU de sexta generación de Google. En todas las plataformas técnicas, como la API y los registros, y a lo largo de este documento, Trillium se denominará v6e. Con 256 chips por Pod, la arquitectura de la versión 6e de TPU comparte muchas similitudes con la versión 5e. La TPU v6e está optimizada para el entrenamiento, el ajuste y el servicio de transformadores, modelos de texto a imagen y redes neuronales convolucionales (CNNs). Para obtener más información sobre la arquitectura y las configuraciones del sistema de TPU v6e, consulta TPU v6e.

Para obtener información sobre cómo ejecutar inferencias en la versión 6e de TPU de Cloud, consulta los siguientes tutoriales:

Antes de empezar

Antes de empezar, debes hacer lo siguiente:

  • Crear una Google Cloud cuenta y un proyecto con la facturación habilitada
  • Instalar los componentes alfa de Google Cloud CLI
  • Habilitar la API de Cloud TPU
  • Crear un agente de servicio de TPU de Cloud
  • Crear una cuenta de servicio de TPU de Cloud y conceder permisos

Para obtener más información, consulta Configurar el entorno de TPU de Cloud.

Verificar la cuota y los permisos

Comprueba que tu proyecto tenga las siguientes cuotas:

Si usas GKE con XPK, necesitas permisos adicionales en la Google Cloud consola. Para obtener más información, consulta Permisos necesarios en la consolaGoogle Cloud .

Provisionar TPUs

Puedes aprovisionar y gestionar TPUs v6e con los siguientes métodos:

  • GKE puedes usar GKE para aprovisionar y gestionar las TPUs como un grupo de aceleradores para tus cargas de trabajo de aprendizaje automático en contenedores. Para obtener más información, consulta el artículo Acerca de las TPUs en GKE.
  • GKE y XPK: XPK es una herramienta de línea de comandos que simplifica la creación de clústeres y la ejecución de cargas de trabajo en GKE. Está diseñado para que los profesionales del aprendizaje automático puedan aprovisionar TPUs y ejecutar trabajos de entrenamiento sin necesidad de tener un gran dominio de Kubernetes. Para obtener más información, consulta el repositorio de GitHub de XPK.
  • Recursos en cola de TPU de Cloud: los recursos en cola te permiten solicitar capacidad de TPU que se aprovisiona cuando está disponible. Es ideal para tareas por lotes y cargas de trabajo tolerantes a fallos que pueden esperar en una cola. Puedes especificar un periodo para tu solicitud. Para obtener más información, consulta Gestionar recursos en cola.

Provisionar TPUs de Cloud v6e con GKE y XPK

Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar Cloud TPUs y entrenar o servir modelos. Consulta Planificar las TPUs de Cloud en GKE para saber cómo planificar las configuraciones de TPUs de Cloud en clústeres de GKE. En las siguientes secciones se proporcionan comandos para crear un clúster XPK con compatibilidad con una sola NIC y con varias NIC.

Crear un clúster XPK con compatibilidad con una sola NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Descripciones de marcas de comandos

Variable Descripción
CLUSTER_NAME Nombre asignado por el usuario al clúster XPK.
PROJECT_ID Google Cloud Nombre del proyecto. Usa un proyecto que ya tengas o crea uno. Para obtener más información, consulta el artículo Configurar un Google Cloud proyecto.
ZONE Consulta el documento Regiones y zonas de TPU de Cloud para ver las zonas admitidas.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES El número de porciones que quieres crear
CLUSTER_ARGUMENTS La red y la subred que se van a usar.

Por ejemplo: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Número de porciones que se van a crear.
NETWORK_NAME Nombre de una red secundaria que se va a usar.
NETWORK_FW_NAME Nombre de un cortafuegos de red secundario que se va a usar.

Crear un clúster de XPK con compatibilidad con varias NICs

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Descripciones de marcas de comandos

Variable Descripción
CLUSTER_NAME Nombre asignado por el usuario al clúster XPK.
PROJECT_ID Google Cloud Nombre del proyecto. Usa un proyecto que ya tengas o crea uno. Para obtener más información, consulta el artículo Configurar un Google Cloud proyecto.
ZONE Consulta el documento Regiones y zonas de TPU de Cloud para ver las zonas admitidas.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES El número de porciones que quieres crear
CLUSTER_ARGUMENTS La red y la subred que se van a usar.

Por ejemplo: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Red de nodos adicional que se va a utilizar.

Por ejemplo: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Número de segmentos que se van a crear (solo es necesario para Multislice).
NETWORK_NAME Nombre de una red secundaria que se va a usar.
NETWORK_FW_NAME Nombre de un cortafuegos de red secundario que se va a usar.

Configurar JAX o PyTorch

En los siguientes recursos se muestra cómo configurar JAX o PyTorch en tu TPU de Cloud, en función del método de aprovisionamiento y gestión que utilices:

Para configurar y ejecutar XPK con MaxText, consulta Ejecutar MaxText a gran escala con XPK .

Optimizar el rendimiento de la red

En esta sección se describe cómo optimizar el rendimiento de su red configurando la unidad de transmisión máxima (MTU), usando varias NICs en entornos Multislice y mejorando los ajustes de TCP.

Configurar la MTU

Para obtener el mejor rendimiento de la red, utiliza una red con una MTU (unidad de transmisión máxima) de 8896.

De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1460 bytes, lo que ofrece un rendimiento de red no óptimo. Puedes definir el valor de MTU de una red VPC entre 1300 y 8896 bytes (ambos incluidos). Los tamaños de MTU personalizados habituales son 1500 bytes (Ethernet estándar) u 8896 bytes (el máximo posible). Para obtener más información, consulta Tamaños de MTU válidos para redes VPC.

Para obtener más información sobre cómo cambiar el ajuste de MTU de una red predeterminada o de una red que ya tengas, consulta Cambiar el ajuste de MTU de una red de VPC.

En el siguiente ejemplo, se crea una red con un MTU de 8896 y una regla de cortafuegos correspondiente que permite el tráfico TCP, ICMP y UDP en la red.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Sustituye your-resource-name por un nombre base para la red y el cortafuegos.

Usar la opción de varias NICs para Multislice

Si usas un entorno Multislice, define las siguientes variables de entorno, que son obligatorias para una subred secundaria:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Usa los siguientes comandos para crear un enrutamiento de IP personalizado para la red y la subred.

  1. Crea la red secundaria.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Crea una subred para la red secundaria.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Crea una regla de cortafuegos para permitir el tráfico en la nueva subred.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Crea un router de Cloud Router para la red secundaria.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Crea una configuración de NAT para el router de Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Después de crear una partición de red múltiple, puedes validar que se están usando ambas tarjetas de interfaz de red (NICs) configurando un clúster XPK y añadiendo la marca --command ifconfig al comando de creación de carga de trabajo XPK.

  1. Usa el siguiente comando workload create para mostrar el resultado del comando ifconfig en los registros de la consola y comprueba que tanto eth0 como eth1 tienen el valor 8896 en MTU. Google Cloud

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Si quieres habilitar los registros de depuración o usar Vertex AI TensorBoard, añade los siguientes argumentos opcionales al comando:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Comprueba que tanto eth0 como eth1 tengan el MTU configurado en 8896. Para ello,consulta la salida de la carga de trabajo de XPK en los registros de la consola Google Cloud .

Mejorar la configuración de TCP

Si has aprovisionado tus Cloud TPUs mediante recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Optimizar el rendimiento de la asignación de memoria

La biblioteca tcmalloc se usa de forma predeterminada en las VMs de TPU de Cloud para mejorar el rendimiento de los modelos con asignaciones de memoria frecuentes y de gran tamaño. Esto se configura mediante la variable de entorno LD_PRELOAD.

Sin embargo, en algunas cargas de trabajo (por ejemplo, DLRM con asignaciones de tablas de inserciones muy grandes), tcmalloc puede provocar una ralentización. En estos casos, puedes volver a la función malloc estándar anulando la variable LD_PRELOAD en tu sesión de shell antes de ejecutar la secuencia de comandos de entrenamiento:

unset LD_PRELOAD

Usar SkyPilot

Puedes usar Cloud TPU v6e con SkyPilot. SkyPilot es un framework de código abierto que simplifica el proceso de ejecutar, gestionar y escalar cargas de trabajo de IA. Puedes añadir información sobre la ubicación y los precios relacionados con v6e a SkyPilot. Para obtener más información, consulta el ejemplo de TPU v6e de SkyPilot.

Ejemplos de entrenamiento

En las siguientes secciones se ofrecen ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en Cloud TPU v6e.

Estos ejemplos se han probado con las siguientes versiones de software:

  • Python 3.10 o versiones posteriores
  • Versiones de software nocturnas:
    • JAX nocturno 0.4.32.dev20240912
    • LibTPU nocturno 0.1.dev20240912+nightly
  • Versiones de software estables:
    • JAX + JAX Lib de la versión 0.4.37

Entrenar MaxText y MaxDiffusion en TPU v6e de Cloud

En las siguientes secciones se describe el ciclo de vida del entrenamiento de los modelos MaxText y MaxDiffusion.

En general, los pasos que debe seguir son los siguientes:

  1. Crea la imagen base de la carga de trabajo.
  2. Ejecuta tu carga de trabajo con XPK.
    1. Crea el comando de entrenamiento para la carga de trabajo.
    2. Despliega la carga de trabajo.
  3. Sigue la carga de trabajo y consulta las métricas.
  4. Elimina la carga de trabajo XPK si no es necesaria.
  5. Elimina el clúster XPK cuando ya no lo necesites.

Crear imagen base

Instala MaxText o MaxDiffusion y crea la imagen de Docker:

  1. Clona el repositorio que quieras usar y cambia al directorio del repositorio:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Configura Docker para usar la CLI de Google Cloud:

    gcloud auth configure-docker
    
  3. Crea la imagen de Docker con el siguiente comando o con una imagen de IA de JAX. Para obtener más información sobre las imágenes de JAX AI, consulta Imágenes de JAX AI.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Define el ID de tu proyecto en la configuración activa de gcloud CLI:

    gcloud config set project ${PROJECT_ID}
    
  5. Si vas a iniciar la carga de trabajo desde una máquina que no tiene la imagen compilada localmente, sube la imagen.

    1. Define la variable de entorno CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Sube la imagen:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Ejecutar una carga de trabajo con XPK

  1. Define las siguientes variables de entorno si no usas los valores predeterminados definidos por MaxText o MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crea la secuencia de comandos del modelo. Esta secuencia de comandos se copiará como comando de entrenamiento en un paso posterior.

    No ejecutes la secuencia de comandos del modelo todavía.

    MaxText

    MaxText es un LLM de código abierto de alto rendimiento y alta escalabilidad escrito en Python y JAX puros, y orientado a Google Cloud TPUs y GPUs para el entrenamiento y la inferencia.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma es una familia de LLMs de código abierto desarrollada por Google DeepMind a partir de la investigación y la tecnología de Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI que utiliza una arquitectura de Mixture-of-Experts (MoE) dispersa.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama es una familia de LLMs de pesos abiertos desarrollada por Meta.

    Para ver un ejemplo de cómo ejecutar Llama 3 en PyTorch, consulta los modelos torch_xla del repositorio torchprime.

    MaxDiffusion

    MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latente escritos en Python y JAX puros que se ejecutan en dispositivos XLA, como las TPUs de Cloud y las GPUs. Stable Diffusion es un modelo de conversión de texto a imagen latente que genera imágenes fotorrealistas a partir de cualquier entrada de texto.

    Debes instalar una rama de Git específica para ejecutar MaxDiffusion, tal como se muestra en la siguiente secuencia de comandos de entrenamiento.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Exporta las siguientes variables:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descripciones de las variables de entorno

    Variable Descripción
    CLUSTER_NAME El nombre de tu clúster XPK.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    NUM_SLICES Número de porciones de TPU.
    YOUR_MODEL_SCRIPT El script del modelo que se va a ejecutar como comando de entrenamiento.
  4. Ejecuta el modelo con la secuencia de comandos que has creado en el paso anterior. Debes especificar la marca --base-docker-image para usar la imagen base MaxText o la marca --docker-image y la imagen que quieras usar.

    Puedes añadir las siguientes marcas opcionales:

    • Para habilitar el registro de depuración, incluye la marca --enable-debug-logs. Para obtener más información, consulta Depurar JAX en MaxText.
    • Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard incluyendo la marca --use-vertex-tensorboard. Para obtener más información, consulta Monitorizar JAX en MaxText con Vertex AI.
    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    El resultado incluye un enlace para seguir tu carga de trabajo. Abre el enlace y haz clic en la pestaña Registros para monitorizar tu carga de trabajo en tiempo real.

Depurar JAX en MaxText

Usa comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:

Monitorizar JAX en MaxText con Vertex AI

Para usar TensorBoard, tu cuenta de usuario debe tener el rol aiplatform.user . Google Cloud Ejecuta el siguiente comando para conceder este rol:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Consulta datos escalares y de perfil a través de TensorBoard gestionado de Vertex AI.

  1. Aumenta las solicitudes de gestión de recursos (CRUD) de la zona que estás usando de 600 a 5000. Esto no debería suponer un problema para cargas de trabajo pequeñas que usen menos de 16 VMs.

  2. Instala dependencias como cloud-accelerator-diagnostics para Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea tu clúster XPK con la marca --create-vertex-tensorboard, tal como se describe en el artículo Crear Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres que ya tengas.

  4. Crea tu experimento de Vertex AI al ejecutar tu carga de trabajo de XPK con la marca --use-vertex-tensorboard y la marca opcional --experiment-name. Para ver la lista completa de pasos, consulta Crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.

Los registros incluyen un enlace a Vertex AI TensorBoard, similar al siguiente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

También puedes encontrar el enlace de Vertex AI TensorBoard en la Google Cloud consola. Ve a Vertex AI Experiments en la Google Cloud consola. Selecciona la región correspondiente en el menú desplegable.

El directorio de TensorBoard también se escribe en el segmento de Cloud Storage que has especificado con ${BASE_OUTPUT_DIR}.

Eliminar tu carga de trabajo de XPK

Usa el comando xpk workload delete para eliminar una o varias cargas de trabajo en función del prefijo o del estado de la tarea. Este comando puede ser útil si has enviado cargas de trabajo XPK que ya no necesitas ejecutar o si tienes trabajos atascados en la cola.

Eliminar el clúster XPK

Usa el comando xpk cluster delete para eliminar el clúster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Resultados de las pruebas de rendimiento de MaxDiffusion

Hemos ejecutado la secuencia de comandos de entrenamiento de MaxDiffusion en una v6e-4, una v6e-16 y dos v6e-16. En la tabla siguiente se muestran los rendimientos medidos.

v6e-4 v6e-16 Dos v6e-16
Pasos de formación 0,069 0,073 0,13
Tamaño de lote global 8 32 64
Rendimiento (ejemplos/s) 115,9 438,4 EGP 492,3

Entrenar modelos Llama con PyTorch/XLA en TPU v6e de Cloud

En esta sección se describe cómo entrenar modelos Llama con PyTorch/XLA en Cloud TPU v6e usando el conjunto de datos WikiText.

Acceder a Hugging Face y al modelo Llama 3

Para este ejemplo, necesitas un token de acceso de usuario de Hugging Face. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre tokens de acceso de usuario.

También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.

Crear una VM de TPU de Cloud

Crea una TPU de Cloud v6e con 8 chips para este ejemplo.

  1. Configura las variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID El ID de tu proyecto Google Cloud . Usa un proyecto que ya tengas o crea uno.
    TPU_NAME El nombre de la TPU.
    ZONE La zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la TPU de Cloud que quieres crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION La versión de software de la TPU de Cloud.

  2. Crea una VM de TPU de Cloud:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Instalación

Instala la bifurcación pytorch-tpu/transformers de las transformaciones de Hugging Face y las dependencias. Este ejemplo se ha probado con las siguientes versiones de dependencia:

  • torch: compatible con la versión 2.5.0
  • torch_xla[tpu]: compatible con la versión 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Configurar archivos de configuración de modelos

El comando de entrenamiento de la siguiente sección, Ejecutar el modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de Fully Sharded Data Parallel (FSDP). El particionado de FSDP te permite usar un tamaño de lote mayor durante el entrenamiento al particionar las ponderaciones de tu modelo en varias TPUs. Cuando se entrena con modelos más pequeños, puede ser suficiente con usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo fragmentar tensores entre dispositivos en PyTorch/XLA, consulta la guía de usuario de SPMD de PyTorch/XLA.

  1. Crea el archivo de configuración de parámetros del modelo. A continuación, se muestra la configuración de los parámetros del modelo Llama-3-8B. En el caso de otros modelos, busca el archivo de configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Crea el archivo de configuración de FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para obtener más información sobre FSDP, consulta Paralelismo de datos totalmente fragmentado con SPMD .

  3. Sube los archivos de configuración a tus VMs de Cloud TPU con el siguiente comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Ejecutar el modelo

Con los archivos de configuración que has creado en la sección anterior, ejecuta la secuencia de comandos run_clm.py para entrenar el modelo Llama-3-8B con el conjunto de datos WikiText. El script de entrenamiento tarda aproximadamente 10 minutos en ejecutarse en una TPU de Cloud v6e-8.

  1. Inicia sesión en Hugging Face en tu TPU de Cloud con el siguiente comando:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Ejecuta el entrenamiento del modelo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Solución de problemas de PyTorch/XLA

Si has definido las variables opcionales para depurar en la sección anterior, el perfil del modelo se almacenará en la ubicación especificada por la variable PROFILE_LOGDIR. Puedes extraer el archivo xplane.pb almacenado en esta ubicación y usar tensorboard para ver los perfiles en tu navegador siguiendo las instrucciones de TensorBoard.

Si PyTorch/XLA no funciona como esperabas, consulta la guía de solución de problemas, que incluye sugerencias para depurar, perfilar y optimizar tu modelo.