Entrena un modelo con la TPU v6e

En este documento, se explica cómo entrenar modelos en Cloud TPU v6e (también llamada Trillium), y se abarca la configuración del entorno, la optimización del rendimiento y ejemplos prácticos de entrenamiento con JAX y PyTorch/XLA.

La TPU v6e, también llamada Trillium, es la 6ª generación de TPUs de Google. En todas las plataformas técnicas, como la API y los registros, y en todo este documento, se hará referencia a Trillium como v6e. Con 256 chips por Pod, la arquitectura de la TPU v6e comparte muchas similitudes con la v5e. La TPU v6e está optimizada para el entrenamiento, el ajuste y la publicación de modelos de transformadores, de texto a imagen y de redes neuronales convolucionales (CNN). Para obtener más información sobre la arquitectura y las configuraciones del sistema de la TPU v6e, consulta TPU v6e.

Para obtener información sobre cómo ejecutar la inferencia en Cloud TPU v6e, consulta los siguientes instructivos:

Antes de comenzar

Antes de comenzar, debes hacer lo siguiente:

  • Crea una Google Cloud cuenta y un proyecto con la facturación habilitada
  • Instala los componentes alfa de Google Cloud CLI
  • Habilita la API de Cloud TPU
  • Crea un agente de servicio de Cloud TPU
  • Crea una cuenta de servicio de Cloud TPU y otorga permisos

Para obtener más información, consulta Configura el entorno de Cloud TPU.

Verifica la cuota y los permisos

Verifica que tu proyecto tenga las siguientes cuotas:

Si usas GKE con XPK, necesitas permisos adicionales en la consola de Google Cloud . Para obtener más información, consulta Permisos necesarios en la consola deGoogle Cloud .

Aprovisiona TPU

Puedes aprovisionar y administrar las TPU v6e con los siguientes métodos:

  • GKE: Puedes usar GKE para aprovisionar y administrar TPUs como un grupo de aceleradores para tus cargas de trabajo de aprendizaje automático alojadas en contenedores. Para obtener más información, consulta Acerca de las TPU en GKE.
  • GKE y XPK: XPK es una herramienta de línea de comandos que simplifica la creación de clústeres y la ejecución de cargas de trabajo en GKE. Está diseñado para que los profesionales del AA aprovisionen TPU y ejecuten trabajos de entrenamiento sin necesidad de tener un profundo conocimiento de Kubernetes. Para obtener más información, consulta el repositorio de GitHub de XPK.
  • Recursos en cola de Cloud TPU: Los recursos en cola te permiten solicitar capacidad de TPU que se aprovisiona cuando está disponible. Es ideal para trabajos por lotes y cargas de trabajo tolerantes a errores que pueden esperar en una cola. Puedes especificar un período para tu solicitud. Para obtener más información, consulta Administra recursos en cola.

Aprovisiona Cloud TPU v6e con GKE y XPK

Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar Cloud TPUs y entrenar o entregar modelos. Consulta Planifica el uso de Cloud TPUs en GKE para obtener información sobre cómo planificar tus configuraciones de Cloud TPU en clústeres de GKE. En las siguientes secciones, se proporcionan comandos para crear un clúster de XPK con compatibilidad para una sola NIC y para varias NIC.

Crea un clúster de XPK con compatibilidad para una sola NIC

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT_ID} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network=${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE} \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --create-vertex-tensorboard

Descripciones de las marcas de comandos

Variable Descripción
CLUSTER_NAME Es el nombre asignado por el usuario para el clúster de XPK.
ID DEL PROYECTO Nombre del proyectoGoogle Cloud . Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Configura tu proyecto de Google Cloud .
ZONA Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES Cantidad de segmentos que deseas crear
CLUSTER_ARGUMENTS La red y la subred que se usarán.

Por ejemplo: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}

NUM_SLICES Es la cantidad de segmentos que se crearán.
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Crea un clúster de XPK con compatibilidad para varias NIC

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
gcloud compute networks create ${NETWORK_NAME_1} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_1} \
   --network=${NETWORK_NAME_1} \
   --range=10.11.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_1} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_1} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
# Secondary subnet for multi-nic experience.
# Need custom IP routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
gcloud compute networks create ${NETWORK_NAME_2} \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=${PROJECT_ID}
gcloud compute networks subnets create ${SUBNET_NAME_2} \
   --network=${NETWORK_NAME_2} \
   --range=10.10.0.0/18 \
   --region=${REGION} \
   --project=${PROJECT_ID}
gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
   --network=${NETWORK_NAME_2} \
   --allow tcp,icmp,udp \
   --project=${PROJECT_ID}
gcloud compute routers create ${ROUTER_NAME} \
   --project=${PROJECT_ID} \
   --network=${NETWORK_NAME_2} \
   --region=${REGION}
gcloud compute routers nats create ${NAT_CONFIG} \
   --router=${ROUTER_NAME} \
   --region=${REGION} \
   --auto-allocate-nat-external-ips \
   --nat-all-subnet-ip-ranges \
   --project=${PROJECT_ID} \
   --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 xpk.py cluster create \
   --cluster=${CLUSTER_NAME} \
   --cluster-cpu-machine-type=e2-standard-8 \
   --num-slices=${NUM_SLICES} \
   --tpu-type=${TPU_TYPE} \
   --zone=${ZONE}  \
   --project=${PROJECT_ID} \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
   --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
   --create-vertex-tensorboard

Descripciones de las marcas de comandos

Variable Descripción
CLUSTER_NAME Es el nombre asignado por el usuario para el clúster de XPK.
ID DEL PROYECTO Nombre del proyectoGoogle Cloud . Usa un proyecto existente o crea uno nuevo. Para obtener más información, consulta Configura tu proyecto de Google Cloud .
ZONA Consulta el documento Regiones y zonas de Cloud TPU para conocer las zonas compatibles.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES Cantidad de segmentos que deseas crear
CLUSTER_ARGUMENTS La red y la subred que se usarán.

Por ejemplo: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}

NODE_POOL_ARGUMENTS Es la red de nodos adicional que se usará.

Por ejemplo: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}

NUM_SLICES Cantidad de segmentos que se crearán (solo es necesario para Multislice).
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Configura JAX o PyTorch

En los siguientes recursos, se muestra cómo configurar JAX o PyTorch en tu Cloud TPU, según el método de aprovisionamiento y administración que uses:

Para configurar y ejecutar XPK con MaxText, consulta Cómo ejecutar MaxText a gran escala con XPK .

Optimiza el rendimiento de la red

En esta sección, se describe cómo optimizar el rendimiento de tu red configurando la unidad de transmisión máxima (MTU), usando varias NIC para entornos de Multislice y mejorando la configuración de TCP.

Configura la MTU

Para obtener el mejor rendimiento de la red, usa una red con una MTU (unidad de transmisión máxima) de 8,896.

De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que genera un rendimiento de red subóptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.

Para obtener más información sobre cómo cambiar el parámetro de configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.

En el siguiente ejemplo,se crea una red con un MTU de 8, 896 y una regla de firewall correspondiente que permite el tráfico de TCP, ICMP y UDP dentro de la red.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
    --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network=${NETWORK_NAME} \
    --allow tcp,icmp,udp --project=${PROJECT_ID}

Reemplaza your-resource-name por un nombre base para la red y el firewall.

Usa la opción de varias NIC para Multislice

Si usas un entorno de Multislice, configura las siguientes variables de entorno, que son obligatorias para una subred secundaria:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Usa los siguientes comandos para crear un enrutamiento de IP personalizado para la red y la subred.

  1. Crea la red secundaria.

    gcloud compute networks create ${NETWORK_NAME_2} --mtu=8896 \
    --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
    
  2. Crea una subred para la red secundaria.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
    --network=${NETWORK_NAME_2} \
    --range=10.10.0.0/18 --region=${REGION} \
    --project=${PROJECT_ID}
    
  3. Crea una regla de firewall para permitir el tráfico dentro de la subred nueva.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
    --network=${NETWORK_NAME_2} --allow tcp,icmp,udp \
    --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}
    
  4. Crea un Cloud Router para la red secundaria.

    gcloud compute routers create ${ROUTER_NAME} \
    --project=${PROJECT_ID} \
    --network=${NETWORK_NAME_2} \
    --region=${REGION}
    
  5. Crea una configuración de NAT para el Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
    --router=${ROUTER_NAME} \
    --region=${REGION} \
    --auto-allocate-nat-external-ips \
    --nat-all-subnet-ip-ranges \
    --project=${PROJECT_ID} \
    --enable-logging
    

Después de crear un segmento de varias redes, puedes validar que se usen ambas tarjetas de interfaz de red (NIC) configurando un clúster de XPK y agregando la marca --command ifconfig al comando de creación de la carga de trabajo de XPK.

  1. Usa el siguiente comando workload create para mostrar el resultado del comando ifconfig en los registros de la consola de Google Cloud y verifica que eth0 y eth1 tengan el MTU establecido en 8,896.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
        --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
        --tpu-type=${ACCELERATOR_TYPE} \
        --num-slices=${NUM_SLICES}  \
        --on-demand \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --command "ifconfig"

    Si deseas habilitar los registros de depuración o usar Vertex AI TensorBoard, agrega los siguientes argumentos opcionales al comando:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Verifica que tanto eth0 como eth1 tengan la MTU establecida en 8,896. Para ello, revisa el resultado de la carga de trabajo de XPK en los registros de la consola de Google Cloud .

Mejora la configuración de TCP

Si aprovisionaste tus Cloud TPU con recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --node=all \
    --worker=all \
    --command='
    sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'

Optimiza el rendimiento de la asignación de memoria

La biblioteca tcmalloc se usa de forma predeterminada en las VMs de Cloud TPU para mejorar el rendimiento de los modelos con asignaciones de memoria frecuentes y de gran tamaño. Se configura a través de la variable de entorno LD_PRELOAD.

Sin embargo, para algunas cargas de trabajo (por ejemplo, DLRM con asignaciones de tablas de incorporación muy grandes), tcmalloc puede causar una ralentización. En esos casos, puedes volver a la función malloc estándar anulando la variable LD_PRELOAD en tu sesión de shell antes de ejecutar la secuencia de comandos de entrenamiento:

unset LD_PRELOAD

Usa SkyPilot

Puedes usar Cloud TPU v6e con SkyPilot. SkyPilot es un framework de código abierto que simplifica el proceso de ejecución, administración y escalamiento de cargas de trabajo de IA. Puedes agregar información sobre la ubicación y los precios relacionados con v6e a SkyPilot. Para obtener más información, consulta el ejemplo de TPU v6e de SkyPilot.

Ejemplos de entrenamiento

En las siguientes secciones, se proporcionan ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en Cloud TPU v6e.

Estos ejemplos se probaron con las siguientes versiones de software:

  • Python 3.10 o una versión posterior
  • Versiones de software nocturnas:
    • JAX nocturno 0.4.32.dev20240912
    • LibTPU nocturna 0.1.dev20240912+nightly
  • Versiones de software estables:
    • JAX y JAX Lib de la versión 0.4.37

Entrena MaxText y MaxDiffusion en Cloud TPU v6e

En las siguientes secciones, se abarca el ciclo de vida del entrenamiento de los modelos MaxText y MaxDiffusion.

En general, los pasos de alto nivel son los siguientes:

  1. Compila la imagen base de la carga de trabajo.
  2. Ejecuta tu carga de trabajo con XPK.
    1. Compila el comando de entrenamiento para la carga de trabajo.
    2. Implementa la carga de trabajo.
  3. Sigue la carga de trabajo y consulta las métricas.
  4. Borra la carga de trabajo de XPK si no es necesaria.
  5. Borra el clúster de XPK cuando ya no lo necesites.

Compila la imagen base

Instala MaxText o MaxDiffusion y compila la imagen de Docker:

  1. Clona el repositorio que deseas usar y cambia al directorio del repositorio:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    
  2. Configura Docker para usar Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Compila la imagen de Docker con el siguiente comando o con una imagen de IA de JAX. Para obtener más información sobre las imágenes de IA de JAX, consulta Imágenes de IA de JAX.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack MODE=jax_ai_image PROJECT=${PROJECT_ID} LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest
    
  4. Configura tu ID del proyecto en la configuración activa de gcloud CLI:

    gcloud config set project ${PROJECT_ID}
    
  5. Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen.

    1. Establece la variable de entorno CLOUD_IMAGE_NAME:

      export CLOUD_IMAGE_NAME=${USER}_runner
      
    2. Sube la imagen:

      bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}
      

Ejecuta tu carga de trabajo con XPK

  1. Establece las siguientes variables de entorno si no usas los valores predeterminados establecidos por MaxText o MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Crea tu secuencia de comandos del modelo. Esta secuencia de comandos se copiará como un comando de entrenamiento en un paso posterior.

    Aún no ejecutes la secuencia de comandos del modelo.

    MaxText

    MaxText es un LLM de código abierto de alto rendimiento y altamente escalable escrito en Python y JAX puros, y orientado a Google Cloud TPUs y GPUs para el entrenamiento y la inferencia.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
         base_output_directory=${BASE_OUTPUT_DIR} \
         dataset_type=synthetic \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
         enable_checkpointing=false \
         gcs_metrics=true \
         profiler=xplane \
         skip_first_n_steps_for_profiler=5 \
         steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma es una familia de LLMs de código abierto desarrollados por Google DeepMind, basados en la investigación y la tecnología de Gemini.

    python3 -m MaxText.train MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI que utiliza una arquitectura de mezcla de expertos (MoE) dispersa.

    python3 -m MaxText.train MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama es una familia de LLMs de código abierto desarrollados por Meta.

    Para ver un ejemplo de cómo ejecutar Llama3 en PyTorch, consulta los modelos de torch_xla en el repositorio de torchprime.

    MaxDiffusion

    MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las Cloud TPU y las GPU. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.

    Debes instalar una rama de Git específica para ejecutar MaxDiffusion, como se muestra en la siguiente secuencia de comandos de entrenamiento.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103
    && pip install -r requirements.txt && pip install .
    && pip install huggingface_hub==0.30.2 && OUT_DIR=${BASE_OUTPUT_DIR}
    && python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR} \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash \
        run_name=sdxl-ddp-v6e
    
  3. Exporta las siguientes variables:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Descripciones de las variables de entorno

    Variable Descripción
    CLUSTER_NAME Es el nombre de tu clúster de XPK.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    NUM_SLICES Es la cantidad de porciones de TPU.
    YOUR_MODEL_SCRIPT Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento.
  4. Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca --base-docker-image para usar la imagen base de MaxText o la marca --docker-image y la imagen que deseas usar.

    Puedes agregar las siguientes marcas opcionales:

    • Puedes habilitar el registro de depuración incluyendo la marca --enable-debug-logs. Para obtener más información, consulta Cómo depurar JAX en MaxText.
    • Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard incluyendo la marca --use-vertex-tensorboard. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.
    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES}  \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    El resultado incluye un vínculo para seguir tu carga de trabajo. Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.

Cómo depurar JAX en MaxText

Usa comandos de XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:

Supervisa JAX en MaxText con Vertex AI

Para usar TensorBoard, tu cuenta de usuario Google Cloud debe tener el rol de aiplatform.user. Ejecuta el siguiente comando para otorgar este rol:

gcloud projects add-iam-policy-binding your-project-id \
   --member='user:your-email' \
   --role='roles/aiplatform.user'

Visualiza datos de perfil y escalares a través de TensorBoard administrado por Vertex AI.

  1. Aumenta las solicitudes de administración de recursos (CRUD) para la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.

  2. Instala dependencias como cloud-accelerator-diagnostics para Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea tu clúster de XPK con la marca --create-vertex-tensorboard, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.

  4. Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca --use-vertex-tensorboard y la marca opcional --experiment-name. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.

Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

También puedes encontrar el vínculo de TensorBoard de Vertex AI en la consola de Google Cloud . Ve a Vertex AI Experiments en la consola de Google Cloud . Selecciona la región adecuada en el menú desplegable.

El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}.

Borra tu carga de trabajo de XPK

Usa el comando xpk workload delete para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos atascados en la cola.

Borra tu clúster de XPK

Usa el comando xpk cluster delete para borrar el clúster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
    --zone=${ZONE} --project=${PROJECT_ID}

Resultados de las comparativas de MaxDiffusion

Ejecutamos la secuencia de comandos de entrenamiento de MaxDiffusion en una v6e-4, una v6e-16 y dos v6e-16. En la siguiente tabla, se muestran las capacidades de procesamiento medidas.

v6e-4 v6e-16 Dos v6e-16
Pasos de entrenamiento 0.069 0.073 0.13
Tamaño del lote global 8 32 64
Capacidad de procesamiento (ejemplos/s) 115.9 438.4 492.3

Entrena modelos de Llama con PyTorch/XLA en Cloud TPU v6e

En esta sección, se describe cómo entrenar modelos de Llama con PyTorch/XLA en Cloud TPU v6e usando el conjunto de datos de WikiText.

Obtén acceso a Hugging Face y al modelo de Llama 3

Para este ejemplo, necesitas un token de acceso de usuario de Hugging Face. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre tokens de acceso de usuario.

También necesitas permiso para acceder al modelo Llama-3-8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.

Crea una VM de Cloud TPU

Crea una Cloud TPU v6e con 8 chips para este ejemplo.

  1. Configure las variables de entorno:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Descripciones de las variables de entorno

    Variable Descripción
    PROJECT_ID El ID de tu proyecto Google Cloud . Usa un proyecto existente o crea uno nuevo.
    TPU_NAME Nombre de la TPU.
    ZONE Es la zona en la que se creará la VM de TPU. Para obtener más información sobre las zonas admitidas, consulta Regiones y zonas de TPU.
    ACCELERATOR_TYPE El tipo de acelerador especifica la versión y el tamaño de la Cloud TPU que deseas crear. Para obtener más información sobre los tipos de aceleradores compatibles con cada versión de TPU, consulta Versiones de TPU.
    RUNTIME_VERSION Versión de software de Cloud TPU.

  2. Crea una VM de Cloud TPU:

    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --version=${RUNTIME_VERSION} \
       --accelerator-type=${ACCELERATOR_TYPE} \
       --zone=${ZONE} \
       --project=${PROJECT_ID}

Instalación

Instala la bifurcación de pytorch-tpu/transformers de Transformers de Hugging Face y las dependencias. Este ejemplo se probó con las siguientes versiones de dependencias:

  • torch: Compatible con la versión 2.5.0
  • torch_xla[tpu]: Compatible con la versión 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
   --project=${PROJECT_ID} \
   --zone ${ZONE} \
   --worker=all \
   --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   sudo pip3 install -e .
   pip3 install datasets
   pip3 install evaluate
   pip3 install scikit-learn
   pip3 install accelerate
   pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
   pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

Cómo configurar archivos de configuración del modelo

El comando de entrenamiento de la siguiente sección, Ejecuta el modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de Fully Sharded Data Parallel (FSDP). El sharding de FSDP te permite usar un tamaño de lote más grande durante el entrenamiento, ya que fragmenta los pesos del modelo en varias TPU. Cuando se entrena con modelos más pequeños, podría ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Si deseas obtener más información para fragmentar tensores en dispositivos con PyTorch/XLA, consulta la guía del usuario de SPMD de PyTorch/XLA.

  1. Crea el archivo de configuración de parámetros del modelo. A continuación, se muestra la configuración de los parámetros del modelo para Llama-3-8B. Para otros modelos, busca el archivo de configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama-2-7B.

    cat > llama-config.json << EOF
    {
      "architectures": [
        "LlamaForCausalLM"
      ],
      "attention_bias": false,
      "attention_dropout": 0.0,
      "bos_token_id": 128000,
      "eos_token_id": 128001,
      "hidden_act": "silu",
      "hidden_size": 4096,
      "initializer_range": 0.02,
      "intermediate_size": 14336,
      "max_position_embeddings": 8192,
      "model_type": "llama",
      "num_attention_heads": 32,
      "num_hidden_layers": 32,
      "num_key_value_heads": 8,
      "pretraining_tp": 1,
      "rms_norm_eps": 1e-05,
      "rope_scaling": null,
      "rope_theta": 500000.0,
      "tie_word_embeddings": false,
      "torch_dtype": "bfloat16",
      "transformers_version": "4.40.0.dev0",
      "use_cache": false,
      "vocab_size": 128256
    }
    EOF
    
  2. Crea el archivo de configuración de FSDP:

    cat > fsdp-config.json << EOF
    {
      "fsdp_transformer_layer_cls_to_wrap": [
        "LlamaDecoderLayer"
      ],
      "xla": true,
      "xla_fsdp_v2": true,
      "xla_fsdp_grad_ckpt": true
    }
    EOF
    

    Para obtener más información sobre FSDP, consulta Fully Sharded Data Parallel using SPMD .

  3. Sube los archivos de configuración a tus VMs de Cloud TPU con el siguiente comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
       --worker=all \
       --project=${PROJECT_ID} \
       --zone=${ZONE}

Ejecuta el modelo

Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py para entrenar el modelo Llama-3-8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda alrededor de 10 minutos en ejecutarse en una Cloud TPU v6e-8.

  1. Accede a Hugging Face en tu Cloud TPU con el siguiente comando:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       pip3 install "huggingface_hub[cli]"
       huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Ejecuta el entrenamiento de modelos:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --project=${PROJECT_ID} \
       --zone ${ZONE} \
       --worker=all \
       --command='
       export PJRT_DEVICE=TPU
       export XLA_USE_SPMD=1
       export ENABLE_PJRT_COMPATIBILITY=true
       # Optional variables for debugging:
       export XLA_IR_DEBUG=1
       export XLA_HLO_DEBUG=1
       export PROFILE_EPOCH=0
       export PROFILE_STEP=3
       export PROFILE_DURATION_MS=100000
       # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
       export PROFILE_LOGDIR=PROFILE_PATH
       python3 transformers/examples/pytorch/language-modeling/run_clm.py \
         --dataset_name wikitext \
         --dataset_config_name wikitext-2-raw-v1 \
         --per_device_train_batch_size 16 \
         --do_train \
         --output_dir /home/$USER/tmp/test-clm \
         --overwrite_output_dir \
         --config_name /home/$USER/llama-config.json \
         --cache_dir /home/$USER/cache \
         --tokenizer_name meta-llama/Meta-Llama-3-8B \
         --block_size 8192 \
         --optim adafactor \
         --save_strategy no \
         --logging_strategy no \
         --fsdp "full_shard" \
         --fsdp_config /home/$USER/fsdp-config.json \
         --torch_dtype bfloat16 \
         --dataloader_drop_last yes \
         --flash_attention \
         --max_steps 20'

Solución de problemas de PyTorch/XLA

Si configuraste las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación especificada por la variable PROFILE_LOGDIR. Puedes extraer el archivo xplane.pb almacenado en esta ubicación y usar tensorboard para ver los perfiles en tu navegador con las instrucciones de TensorBoard.

Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que incluye sugerencias para depurar, generar perfiles y optimizar tu modelo.