Introducción a Trillium (v6e)

En esta documentación, la API de TPU y los registros, v6e se usa para referirse a Trillium. v6e representa la 6ª generación de TPU de Google.

Con 256 chips por Pod, la arquitectura v6e comparte muchas similitudes con la v5e. Este sistema está optimizado para el entrenamiento, la optimización y la entrega de transformadores, texto a imagen y redes neuronales convolucionales (CNN).

Consulta el documento v6e para obtener información sobre la arquitectura y las configuraciones del sistema v6e.

En este documento de introducción, se enfocan los procesos de entrenamiento y publicación de modelos con los frameworks de JAX, PyTorch o TensorFlow. Con cada framework, puedes aprovisionar TPU con recursos en cola o Google Kubernetes Engine (GKE). La configuración de GKE se puede realizar con XPK o comandos de GKE.

Procedimiento general para entrenar o entregar un modelo con la versión v6e

  1. Prepara un Google Cloud proyecto
  2. Capacidad segura
  3. Configura tu entorno de TPU
  4. Aprovisiona el entorno de Cloud TPU
  5. Ejecuta una carga de trabajo de entrenamiento o inferencia de modelos
  6. Realiza una limpieza

Prepara un Google Cloud proyecto

  1. Accede a tu Cuenta de Google. Si aún no lo hiciste, regístrate para obtener una nueva cuenta.
  2. En la consola de Google Cloud, selecciona o crea un proyecto de Cloud en la página del selector de proyectos.
  3. Habilita la facturación para tu proyecto de Google Cloud. La facturación es obligatoria para todo el uso de Google Cloud.
  4. Instala los componentes de gcloud alpha.
  5. Ejecuta el siguiente comando para instalar la versión más reciente de los componentes de gcloud.

    gcloud components update
    
  6. Habilita la API de TPU con el siguiente comando gcloud en Cloud Shell. También puedes habilitarlo desde la consola de Google Cloud.

    gcloud services enable tpu.googleapis.com
    
  7. Habilita permisos con la cuenta de servicio de TPU para la API de Compute Engine

    Las cuentas de servicio permiten que el servicio de Cloud TPU acceda a otros servicios de Google Cloud. Una cuenta de servicio administrada por el usuario es una práctica recomendada de Google Cloud. Sigue estas guías para crear y otorgar roles. Se requieren los siguientes roles:

    • Administrador de TPU
    • Administrador de almacenamiento
    • Escritor de registros
    • Escritor de métricas de Monitoring

    a. Configura los permisos de XPK con tu cuenta de usuario de GKE: XPK.

  8. Realiza la autenticación con tu Cuenta de Google y establece el ID y la zona del proyecto predeterminados.
    auth login autoriza a gcloud a acceder a Google Cloud con credenciales de usuario de Google.
    PROJECT_ID es el Google Cloud nombre del proyecto.
    ZONE es la zona en la que deseas crear la TPU.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Crea una identidad de servicio para la VM de TPU.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Cómo proteger la capacidad

Comunícate con el equipo de asistencia de ventas o de cuentas de Cloud TPU para solicitar la cuota de TPU y responder cualquier pregunta sobre la capacidad.

Aprovisiona el entorno de Cloud TPU

Las TPU v6e se pueden aprovisionar y administrar con GKE, con GKE y XPK (una herramienta de wrapper de CLI sobre GKE) o como recursos en cola.

Requisitos previos

  • Verifica que tu proyecto tenga suficiente cuota de TPUS_PER_TPU_FAMILY, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto deGoogle Cloud .
  • La versión 6e se probó con la siguiente configuración:
    • Python 3.10 o versiones posteriores
    • Versiones de software nocturnas:
      • JAX nocturno 0.4.32.dev20240912
      • LibTPU nocturna 0.1.dev20240912+nightly
    • Versiones de software estables:
      • JAX + JAX Lib de la versión 0.4.37
  • Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:

    • Cuota de VM de TPU
    • Cuota de direcciones IP
    • Quota de Hyperdisk Balanced

  • Permisos del proyecto del usuario

Variables de entorno

En Cloud Shell, crea las siguientes variables de entorno:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-east1-d
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for provisioning Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the default network becoming overloaded.

export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descripciones de las marcas de comandos

Variable Descripción
NODE_ID El ID asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en fila.
ID DEL PROYECTO Google Cloud Nombre del proyecto. Usa un proyecto existente o crea uno nuevo en
ZONA Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles.
ACCELERATOR_TYPE Consulta Tipos de aceleradores.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en la Google Cloud Console -> IAM -> Cuentas de servicio.

Por ejemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES Es la cantidad de rebanadas que se deben crear (solo es necesario para Multislice).
QUEUED_RESOURCE_ID El ID de texto asignado por el usuario de la solicitud de recursos en cola.
VALID_DURATION Es la duración durante la cual la solicitud de recursos en cola es válida.
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Optimizaciones del rendimiento de la red

Para obtener el mejor rendimiento, usa una red con 8,896 MTU (unidad de transmisión máxima).

De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que proporcionará un rendimiento de red poco óptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.

Para obtener más información sobre cómo cambiar la configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.

En el siguiente ejemplo, se crea una red con 8,896 MTU.

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT_ID} \
 --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME}
 --allow tcp,icmp,udp --project=${PROJECT}

Cómo usar varias NIC (opción para Multislice)

Las siguientes variables de entorno son necesarias para una subred secundaria cuando usas un entorno de Multislice.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Usa los siguientes comandos para crear enrutamiento de IP personalizado para la red y la subred.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=${PROJECT_ID}
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project=${PROJECT_ID}

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT_ID}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"

gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT_ID}" \
  --enable-logging

Una vez que se crea una porción de varias redes, puedes validar que se usen ambas NIC configurando un clúster de XPK y ejecutando --command ifconfig como parte de la carga de trabajo de XPK.

Usa el siguiente comando xpk workload para mostrar el resultado del comando ifconfig en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.

python3 xpk.py workload create \
   --cluster CLUSTER_NAME \
   (--base-docker-image maxtext_base_image|--docker-image CLOUD_IMAGE_NAME \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Verifica que tanto eth0 como eth1 tengan mtu=8,896. Una forma de verificar que tienes varios NIC en ejecución es ejecutar el comando --command "ifconfig" como parte de la carga de trabajo de XPK. Luego, observa el resultado impreso de esa carga de trabajo de xpk en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.

Configuración de TCP mejorada

En el caso de las TPU creadas con la interfaz de recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red aumentando los límites del búfer de recepción de TCP.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "$PROJECT" \
  --zone "$ZONE" \
  --node=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"' \
  --worker=all

Aprovisionamiento con recursos en cola

La capacidad asignada se puede aprovisionar con el comando create de recursos en cola.

  1. Crea una solicitud de recurso en cola de TPU.

    La marca --reserved solo es necesaria para los recursos reservados, no para los recursos on demand.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]
    
      # The following flags are only needed if you are using Multislice.
      --node-count node-count  # Number of slices in a Multislice \
      --node-prefix node-prefix # An optional user-defined node prefix;
       the default is QUEUED_RESOURCE_ID.

    Si la solicitud de recursos en cola se crea correctamente, el estado en el campo "response" será "WAITING_FOR_RESOURCES" o "FAILED". Si la solicitud de recursos en cola está en el estado "WAITING_FOR_RESOURCES", significa que el recurso se agregó a la cola y se aprovisionará cuando haya suficiente capacidad de TPU asignada. Si la solicitud de recursos en cola está en el estado "FAILED", el motivo de la falla aparecerá en el resultado. La solicitud de recursos en cola vencerá si no se aprovisiona un v6e dentro de la duración especificada y el estado se convierte en “FAILED”. Consulta la documentación pública de Recursos en cola para obtener más información.

    Cuando tu solicitud de recursos en cola esté en el estado "ACTIVO", podrás conectarte a tus VMs de TPU con SSH. Usa los comandos list o describe para consultar el estado de tu recurso en cola.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Cuando el recurso en cola está en el estado "ACTIVE", el resultado es similar al siguiente:

      state:
       state: ACTIVE
    
  2. Administra tus VMs de TPU. Para conocer las opciones de administración de tus VMs de TPU, consulta Cómo administrar VMs de TPU.

  3. Conéctate a tus VMs de TPU con SSH

    Puedes instalar objetos binarios en cada VM de TPU de tu porción de TPU y ejecutar código. Consulta la sección Tipos de VM para determinar cuántas VMs tendrá tu fragmento.

    Para instalar los objetos binarios o ejecutar código, puedes usar SSH para conectarte a una VM con el comando tpu-vm ssh.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Para usar SSH y conectarte a una VM específica, usa la marca --worker que sigue a un índice basado en 0:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Si tienes formas de rebanada superiores a 8 chips, tendrás varias VM en una rebanada. En este caso, usa los parámetros --worker=all y --command en el comando gcloud alpha compute tpus tpu-vm ssh para ejecutar un comando en todas las VMs de forma simultánea. Por ejemplo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Borrar un recurso en cola

    Borra un recurso en cola al final de la sesión o quita las solicitudes de recursos en cola que estén en el estado "FAILED". Para borrar un recurso en cola, borra la porción y, luego, la solicitud de recurso en cola en 2 pasos:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

Aprovisiona TPU v6e con GKE o XPK

Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar TPU y entrenar o entregar modelos. Consulta Planifica las TPU en GKE para obtener información sobre cómo planificar tus configuraciones de TPU en los clústeres de GKE. En las siguientes secciones, se proporcionan comandos para crear un clúster de XPK con compatibilidad con una sola NIC y con varias NIC.

Comandos para crear un clúster XPK con compatibilidad con una sola NIC

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw
   gcloud compute networks create ${NETWORK_NAME} \
   --mtu=8896 \
   --project=${PROJECT} \
   --subnet-mode=auto \
   --bgp-routing-mode=regional
   gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
   --network ${NETWORK_NAME} \
   --allow tcp,icmp,udp \
   --project=${PROJECT}
export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"
   python3 xpk.py cluster create --cluster $CLUSTER_NAME \
   --cluster-cpu-machine-type=n1-standard-8 \
   --num-slices=$NUM_SLICES \
   --tpu-type=$TPU_TYPE \
   --zone=$ZONE  \
   --project=$PROJECT \
   --on-demand \
   --custom-cluster-arguments="${CLUSTER_ARGUMENTS}"  \
   --create-vertex-tensorboard

Descripciones de las marcas de comandos

Variable Descripción
CLUSTER_NAME Es el nombre asignado por el usuario para el clúster de XPK.
ID DEL PROYECTO Google Cloud Nombre del proyecto. Usa un proyecto existente o crea uno nuevo en
ZONA Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES La cantidad de divisiones que deseas crear
CLUSTER_ARGUMENTS La red y la subred que se usarán.

Por ejemplo: "--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"

NUM_SLICES Es la cantidad de rebanadas que se crearán.
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Comandos para crear un clúster de XPK con compatibilidad con varias NIC

export CLUSTER_NAME xpk-cluster-name
export ZONE=us-central2-b
export PROJECT=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export exportSUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}
   gcloud compute networks create "${NETWORK_NAME_1}" \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=$PROJECT
   gcloud compute networks subnets create "${SUBNET_NAME_1}" \
   --network="${NETWORK_NAME_1}" \
   --range=10.11.0.0/18 \
   --region="${REGION}" \
   --project=$PROJECT
   gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_1}" \
   --allow tcp,icmp,udp \
   --project="${PROJECT}"
  gcloud compute routers create "${ROUTER_NAME}" \
    --project="${PROJECT}" \
    --network="${NETWORK_NAME_1}" \
    --region="${REGION}"
  gcloud compute routers nats create "${NAT_CONFIG}" \
     --router="${ROUTER_NAME}" \
     --region="${REGION}" \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project="${PROJECT}" \
     --enable-logging
Secondary subnet for multi-nic experience. Need custom ip routing to be different from the first network's subnet.

export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}
   gcloud compute networks create "${NETWORK_NAME_2}" \
   --mtu=8896 \
   --bgp-routing-mode=regional \
   --subnet-mode=custom \
   --project=$PROJECT
   gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 \
   --region="${REGION}" \
   --project=$PROJECT
   gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" \
   --allow tcp,icmp,udp \
   --project="${PROJECT}"
   gcloud compute routers create "${ROUTER_NAME}" \
     --project="${PROJECT}" \
     --network="${NETWORK_NAME_2}" \
     --region="${REGION}"
   gcloud compute routers nats create "${NAT_CONFIG}" \
     --router="${ROUTER_NAME}" \
     --region="${REGION}" \
     --auto-allocate-nat-external-ips \
     --nat-all-subnet-ip-ranges \
     --project="${PROJECT}" \
     --enable-logging
export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking
--network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

export NODE_POOL_ARGUMENTS="--additional-node-network
network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"
python3 ~/xpk/xpk.py cluster create \
--cluster $CLUSTER_NAME \
--num-slices=$NUM_SLICES \
--tpu-type=$TPU_TYPE \
--zone=$ZONE  \
--project=$PROJECT \
--on-demand \
--custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
--custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
--create-vertex-tensorboard

Descripciones de las marcas de comandos

Variable Descripción
CLUSTER_NAME Es el nombre asignado por el usuario para el clúster de XPK.
ID DEL PROYECTO Google Cloud Nombre del proyecto. Usa un proyecto existente o crea uno nuevo en
ZONA Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles.
TPU_TYPE Consulta Tipos de aceleradores.
NUM_SLICES La cantidad de divisiones que deseas crear
CLUSTER_ARGUMENTS La red y la subred que se usarán.

Por ejemplo: "--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"

NODE_POOL_ARGUMENTS Es la red de nodos adicional que se usará.

Por ejemplo: "--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"

NUM_SLICES Es la cantidad de rebanadas que se deben crear (solo es necesario para Multislice).
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Configuración del framework

En esta sección, se describe el proceso de configuración general para el entrenamiento de modelos de AA con los frameworks JAX, PyTorch o TensorFlow. Puedes aprovisionar TPUs con recursos en cola o GKE. La configuración de GKE se puede realizar con XPK o comandos de Kubernetes.

Configuración para JAX

En esta sección, se proporcionan ejemplos para ejecutar cargas de trabajo de JAX en GKE, con o sin XPK, así como para usar recursos en cola.

Configura JAX con GKE

En el siguiente ejemplo, se configura un host único de 2 × 2 con un archivo YAML de Kubernetes.

Porción única en un solo host

apiVersion: v1
kind: Pod
metadata:
  name: tpu-pod-jax-v6e-a
spec:
  restartPolicy: Never
  nodeSelector:
    cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
    cloud.google.com/gke-tpu-topology: 2x2
  containers:
  - name: tpu-job
    image: python:3.10
    securityContext:
      privileged: true
    command:
    - bash
    - -c
    - |
      pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python3 -c 'import jax; print("Total TPU chips:", jax.device_count())'
    resources:
      requests:
        google.com/tpu: 4
      limits:
        google.com/tpu: 4

Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:

Total TPU chips: 4

Porción única en varios hosts

En el siguiente ejemplo, se configura un grupo de nodos multihost 4 × 4 con un archivo YAML de Kubernetes.

apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-available-chips
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-available-chips
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
        cloud.google.com/gke-tpu-topology: 4x4
      containers:
      - name: tpu-job
        image: python:3.10
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        command:
        - bash
        - -c
        - |
          pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4

Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:

Total TPU chips: 16

Porciones múltiples en varios hosts

En el siguiente ejemplo, se configuran dos grupos de nodos multihost 4 × 4 con un archivo YAML de Kubernetes.

Como requisito previo, debes instalar JobSet v0.2.3 o una versión posterior.

apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice-job
  annotations:
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 4
  replicatedJobs:
    - name: slice
      replicas: 2
      template:
        spec:
          parallelism: 4
          completions: 4
          backoffLimit: 0
          template:
            spec:
              hostNetwork: true
              dnsPolicy: ClusterFirstWithHostNet
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4
              hostNetwork: true
              containers:
              - name: jax-tpu
                image: python:3.10
                ports:
                - containerPort: 8471
                - containerPort: 8080
                - containerPort: 8431
                securityContext:
                  privileged: true
                command:
                - bash
                - -c
                - |
                  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                  JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true python -c 'import jax; print("Total TPU chips:", jax.device_count())'
                resources:
                  limits:
                   google.com/tpu: 4
                  requests:
                   google.com/tpu: 4

Cuando se complete correctamente, deberías ver el siguiente mensaje en el registro de GKE:

Total TPU chips: 32

Para obtener más información, consulta Cómo ejecutar una carga de trabajo de porciones múltiples en la documentación de GKE.

Para mejorar el rendimiento, habilita hostNetwork.

Multi-NIC

Para aprovechar la NIC múltiple en GKE, el manifiesto del pod de Kubernetes debe tener anotaciones adicionales. El siguiente es un manifiesto de ejemplo de carga de trabajo de varias NIC que no es de TPU.

apiVersion: v1
kind: Pod
metadata:
  name: sample-netdevice-pod-1
  annotations:
    networking.gke.io/default-interface: 'eth0'
    networking.gke.io/interfaces: |
      [
        {"interfaceName":"eth0","network":"default"},
        {"interfaceName":"eth1","network":"netdevice-network"}
      ]
spec:
  containers:
  - name: sample-netdevice-pod
    image: busybox
    command: ["sleep", "infinity"]
    ports:
    - containerPort: 80
  restartPolicy: Always
  tolerations:
  - key: "google.com/tpu"
    operator: "Exists"
    effect: "NoSchedule"

Si exec en el pod de Kubernetes, deberías ver la NIC adicional con el siguiente código.

$ k exec --stdin --tty sample-netdevice-pod-1 -- /bin/sh
/ # ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
2: eth0@if11: <BROADCAST,MULTICAST,UP,LOWER_UP,M-DOWN> mtu 1460 qdisc noqueue
    link/ether da:be:12:67:d2:25 brd ff:ff:ff:ff:ff:ff
    inet 10.124.2.6/24 brd 10.124.2.255 scope global eth0
       valid_lft forever preferred_lft forever
3: eth1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1460 qdisc mq qlen 1000
    link/ether 42:01:ac:18:00:04 brd ff:ff:ff:ff:ff:ff
    inet 172.24.0.4/32 scope global eth1
       valid_lft forever preferred_lft forever

Configura JAX con GKE y XPK

Consulta un ejemplo en el archivo README de xpk.

Para configurar y ejecutar XPK con MaxText, consulta: Cómo ejecutar MaxText.

Configura JAX con recursos en cola

Instala JAX en todas las VMs de TPU de tu porción o porciones de forma simultánea con gcloud alpha compute tpus tpu-vm ssh. Para Multislice, agrega --node=all.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Puedes ejecutar el siguiente código de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran aquí se generaron con un fragmento v6e-16):

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

El resultado es similar a este:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() muestra la cantidad total de chips en la porción determinada. jax.local_device_count() indica la cantidad de chips a los que puede acceder una sola VM en esta porción.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout 975fdb7dbddaa9a53ad72a421cdb487dcdc491a3 &&
   && pip install -r requirements.txt  && pip install . '

Cómo solucionar problemas de configuración de JAX

Una sugerencia general es habilitar el registro detallado en el manifiesto de tu carga de trabajo de GKE. Luego, proporciona los registros al equipo de asistencia de GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Mensajes de error

no endpoints available for service 'jobset-webhook-service'

Este error significa que el conjunto de trabajos no se instaló correctamente. Verifica si se están ejecutando los pods de Kubernetes de la implementación de jobset-controller-manager. Para obtener más información, consulta la documentación de solución de problemas de JobSet.

TPU initialization failed: Failed to connect

Asegúrate de que la versión de tu nodo de GKE sea 1.30.4-gke.1348000 o posterior (no se admite GKE 1.31).

Configuración para PyTorch

En esta sección, se describe cómo comenzar a usar PJRT en v6e con PyTorch/XLA. La versión recomendada de Python es 3.10.

Configura PyTorch con GKE y XPK

Puedes usar el siguiente contenedor de Docker con XPK, que ya tiene instaladas las dependencias de PyTorch:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Para crear una carga de trabajo de XPK, usa el siguiente comando:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -${NUM_SLICES} \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

El uso de --base-docker-image crea una nueva imagen de Docker con el directorio de trabajo actual integrado en el nuevo Docker.

Configura PyTorch con recursos en cola

Sigue estos pasos para instalar PyTorch con recursos en cola y ejecutar una pequeña secuencia de comandos en v6e.

Instala dependencias con SSH para acceder a las VMs

Para Multislice, agrega --node=all:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Mejora el rendimiento de los modelos con asignaciones grandes y frecuentes

En el caso de los modelos que tienen asignaciones frecuentes y de tamaño, observamos que el uso de tcmalloc mejora el rendimiento de manera significativa en comparación con la implementación predeterminada de malloc, por lo que el malloc predeterminado que se usa en la VM de TPU es tcmalloc. Sin embargo, según tu carga de trabajo (por ejemplo, con DLRM, que tiene asignaciones muy grandes para sus tablas de incorporación), tcmalloc puede causar una ralentización, en cuyo caso puedes intentar restablecer la siguiente variable con el malloc predeterminado:

unset LD_PRELOAD

Usa una secuencia de comandos de Python para realizar un cálculo en la VM v6e:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Esto genera un resultado similar al que se muestra a continuación:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configuración para TensorFlow

Para la versión preliminar pública v6e, solo se admite la versión del entorno de ejecución tf-nightly.

Para restablecer tpu-runtime con la versión compatible con TensorFlow v6e, ejecuta los siguientes comandos:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Usa SSH para acceder a worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Instala TensorFlow en worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exporta la variable de entorno TPU_NAME:

export TPU_NAME=v6e-16

Puedes ejecutar la siguiente secuencia de comandos de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran se generaron con un fragmento v6e-16):

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

El resultado es similar a este:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e con SkyPilot

Puedes usar TPU v6e con SkyPilot. Sigue estos pasos para agregar información de ubicación o precios relacionada con la v6e a SkyPilot.

  1. Agrega lo siguiente al final de ~/.sky/catalogs/v5/gcp/vms.csv :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Especifica los siguientes recursos en un archivo YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Inicia un clúster con TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Conéctate a la TPU v6e con SSH: ssh tpu_v6

Instructivos de inferencia

En los siguientes instructivos, se muestra cómo ejecutar la inferencia en TPU v6e:

Ejemplos de entrenamiento

En las siguientes secciones, se proporcionan ejemplos para entrenar modelos de MaxText, MaxDiffusion y PyTorch en TPU v6e.

Entrenamiento de MaxText y MaxDiffusion en la VM de Cloud TPU v6e

En las siguientes secciones, se describe el ciclo de vida de entrenamiento de los modelos MaxText y MaxDiffusion.

En general, los pasos de alto nivel son los siguientes:

  1. Compila la imagen base de la carga de trabajo.
  2. Ejecuta tu carga de trabajo con XPK.
    1. Compila el comando de entrenamiento para la carga de trabajo.
    2. Implementa la carga de trabajo.
  3. Sigue la carga de trabajo y consulta las métricas.
  4. Borra la carga de trabajo de XPK si no la necesitas.
  5. Borra el clúster de XPK cuando ya no lo necesites.

Compila la imagen base

Instala MaxText o MaxDiffusion y compila la imagen de Docker:

  1. Clona el repositorio que deseas usar y cambia al directorio del repositorio:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configura Docker para usar Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Compila la imagen de Docker con el siguiente comando o con la pila estable de JAX. Para obtener más información sobre JAX Stable Stack, consulta Cómo compilar una imagen de Docker con JAX Stable Stack.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.37
    
  4. Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Compila una imagen de Docker con JAX Stable Stack

Puedes compilar las imágenes de Docker de MaxText y MaxDiffusion con la imagen base de la pila estable de JAX.

JAX Stable Stack proporciona un entorno coherente para MaxText y MaxDiffusion, ya que agrupa JAX con paquetes principales, como orbax, flax y optax, junto con una libtpu.so bien calificada que impulsa las utilidades del programa de TPU y otras herramientas esenciales. Estas bibliotecas se prueban para garantizar la compatibilidad y proporcionar una base estable para compilar y ejecutar MaxText y MaxDiffusion. Esto elimina los posibles conflictos debido a versiones de paquetes incompatibles.

La pila estable de JAX incluye una libtpu.so completamente publicada y calificada, la biblioteca principal que impulsa la compilación, la ejecución y la configuración de la red de ICI del programa de TPU. La versión de libtpu reemplaza la compilación nocturna que usaba JAX anteriormente y garantiza una funcionalidad coherente de los cálculos de XLA en TPU con pruebas de calificación a nivel de PJRT en IR de HLO/StableHLO.

Para compilar la imagen de Docker de MaxText y MaxDiffusion con la pila estable de JAX, cuando ejecutes la secuencia de comandos docker_build_dependency_image.sh, establece la variable MODE en stable_stack y la variable BASEIMAGE en la imagen base que deseas usar.

En el siguiente ejemplo, se especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1 como la imagen base:

bash docker_build_dependency_image.sh MODE=stable_stack
BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1

Para obtener una lista de las imágenes base de JAX Stable Stack disponibles, consulta Imágenes de JAX Stable Stack en Artifact Registry.

Ejecuta tu carga de trabajo con XPK

  1. Establece las siguientes variables de entorno si no usas los valores predeterminados establecidos por MaxText o MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Compila la secuencia de comandos de tu modelo. Esta secuencia de comandos se copiará como un comando de entrenamiento en un paso posterior.

    Aún no ejecutes la secuencia de comandos del modelo.

    MaxText

    MaxText es un LLM de código abierto, altamente escalable y de alto rendimiento, escrito en Python y JAX puros, y se orienta a Google Cloud TPU y GPUs para el entrenamiento y la inferencia.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=${BASE_OUTPUT_DIR} \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}  # attention='dot_product'"
    

    Gemma2

    Gemma es una familia de LLM con pesos abiertos que desarrolló Google DeepMind, basada en la investigación y la tecnología de Gemini.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral es un modelo de IA de vanguardia desarrollado por Mistral AI, que utiliza una arquitectura dispersa de mezcla de expertos (MoE).

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama es una familia de LLM de ponderación abierta que desarrolló Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las GPUs y las Cloud TPU. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.

    Debes instalar una rama específica de Git para ejecutar MaxDiffusion, como se muestra en el siguiente comando git checkout.

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Secuencia de comandos de entrenamiento:

        cd maxdiffusion && OUT_DIR=${BASE_OUTPUT_DIR} \
        python src/maxdiffusion/train_sdxl.py \
        src/maxdiffusion/configs/base_xl.yml \
        revision=refs/pr/95 \
        activations_dtype=bfloat16 \
        weights_dtype=bfloat16 \
        resolution=1024 \
        per_device_batch_size=1 \
        output_dir=${OUT_DIR}  \
        jax_cache_dir=${OUT_DIR}/cache_dir/ \
        max_train_steps=200 \
        attention=flash run_name=sdxl-ddp-v6e
    
        
  3. Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca --base-docker-image para usar la imagen base de MaxText o especificar la marca --docker-image y la imagen que deseas usar.

    Opcional: Puedes habilitar el registro de depuración si incluyes la marca --enable-debug-logs. Para obtener más información, consulta Cómo depurar JAX en MaxText.

    Opcional: Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard si incluyes la marca --use-vertex-tensorboard. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
        --tpu-type=$ACCELERATOR_TYPE \
        --num-slices=$NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command $YOUR-MODEL-SCRIPT

    Exporta las siguientes variables:

    export ClUSTER_NAME=CLUSTER_NAME: Es el nombre de tu clúster de XPK. export ACCELERATOR_TYPEACCELERATOR_TYPE: Es la versión y el tamaño de tu TPU. Por ejemplo, v6e-256. export NUM_SLICES=NUM_SLICES: Es la cantidad de porciones de TPU. export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT: Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento.

    El resultado incluye un vínculo para seguir tu carga de trabajo, similar al siguiente:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.

Cómo depurar JAX en MaxText

Usa comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo.

Supervisa JAX en MaxText con Vertex AI

Consulta los datos escalares y de perfil a través de TensorBoard administrado de Vertex AI.

  1. Aumenta las solicitudes de administración de recursos (CRUD) de la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.
  2. Instala dependencias como cloud-accelerator-diagnostics para Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea tu clúster de XPK con la marca --create-vertex-tensorboard, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.

  4. Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca --use-vertex-tensorboard y la marca opcional --experiment-name. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.

Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

También puedes encontrar el vínculo de Vertex AI TensorBoard en la console de Google Cloud. Ve a Experimentos de Vertex AI en la consola de Google Cloud. Selecciona la región adecuada en el menú desplegable.

El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}.

Borra cargas de trabajo de XPK

Usa el comando xpk workload delete para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos que están detenidos en la cola.

Borra el clúster de XPK

Usa el comando xpk cluster delete para borrar un clúster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
--zone $ZONE --project $PROJECT_ID

Entrenamiento de Llama y PyTorch/XLA en la VM de Cloud TPU v6e

En este instructivo, se describe cómo entrenar modelos de Llama con PyTorch/XLA en TPU v6e con el conjunto de datos de WikiText.

Obtén acceso a Hugging Face y al modelo Llama 3

Necesitas un token de acceso de usuario de Hugging Face para ejecutar este instructivo. Para obtener información sobre cómo crear tokens de acceso de usuario, consulta la documentación de Hugging Face sobre los tokens de acceso de usuario.

También necesitas permiso para acceder al modelo Llama 3 8B en Hugging Face. Para obtener acceso, ve al modelo Meta-Llama-3-8B en Hugging Face y solicita acceso.

Crea una VM de TPU

Crea una TPU v6e con 8 chips para ejecutar el instructivo.

  1. Configure las variables de entorno:

    export ACCELERATOR_TYPE=v6e-8
    export VERSION=v2-alpha-tpuv6e
    export TPU_NAME=$USER-$ACCELERATOR_TYPE
    export PROJECT=YOUR_PROJECT
    export ZONE=YOUR_ZONE
  2. Crea una VM de TPU:

    gcloud alpha compute tpus tpu-vm create $TPU_NAME --version=$VERSION \
        --accelerator-type=$ACCELERATOR_TYPE --zone=$ZONE --project=$PROJECT

Instalación

Instala la división pytorch-tpu/transformers de Hugging Face Transformers y las dependencias. Este instructivo se probó con las siguientes versiones de dependencias que se usan en este ejemplo:

  • torch: Compatible con 2.5.0
  • torch_xla[tpu]: Compatible con 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT --zone $ZONE \
    --worker=all --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
    cd transformers
    sudo pip3 install -e .
    pip3 install datasets
    pip3 install evaluate
    pip3 install scikit-learn
    pip3 install accelerate
    pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
    pip install jax==0.4.33 jaxlib==0.4.33 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Configura los parámetros de configuración del modelo

El comando de entrenamiento en la siguiente sección, Run the model, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando se entrena con modelos más pequeños, puede ser suficiente usar el paralelismo de datos y replicar los pesos en cada dispositivo. Para obtener más información sobre cómo dividir tensores en varios dispositivos en PyTorch/XLA, consulta la Guía del usuario de SPMD de PyTorch/XLA.

  1. Crea el archivo de configuración del parámetro del modelo. La siguiente es la configuración del parámetro del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.

    cat > llama-config.json <
    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
    EOF
  2. Crea el archivo de configuración de FSDP:

    cat > fsdp-config.json <
    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }
    EOF

    Para obtener más información sobre FSDP, consulta FSDPv2.

  3. Sube los archivos de configuración a tus VMs de TPU con el siguiente comando:

    gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json $TPU_NAME:. \
        --worker=all \
        --project=$PROJECT \
        --zone $ZONE

Ejecuta el modelo

Con los archivos de configuración que creaste en la sección anterior, ejecuta la secuencia de comandos run_clm.py para entrenar el modelo Llama 3 8B en el conjunto de datos de WikiText. La secuencia de comandos de entrenamiento tarda aproximadamente 10 minutos en ejecutarse en una TPU v6e-8.

  1. Accede a Hugging Face en tu TPU con el siguiente comando:

    gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
        --zone $ZONE \
        --worker=all \
        --command='
        pip3 install "huggingface_hub[cli]"
        huggingface-cli login --token HUGGING_FACE_TOKEN'
  2. Ejecuta el entrenamiento de modelos:

    gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
        --zone $ZONE \
        --worker=all \
        --command='
        export PJRT_DEVICE=TPU
        export XLA_USE_SPMD=1
        export ENABLE_PJRT_COMPATIBILITY=true
            # Optional variables for debugging:
        export XLA_IR_DEBUG=1
        export XLA_HLO_DEBUG=1
        export PROFILE_EPOCH=0
        export PROFILE_STEP=3
        export PROFILE_DURATION_MS=100000
            # Set PROFILE_LOGDIR to a local VM path or gs://my-bucket/profile_path
        export PROFILE_LOGDIR=PROFILE_PATH
        python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 16 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/llama-config.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp-config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20'

Soluciona problemas de PyTorch/XLA

Si configuras las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación que especifique la variable PROFILE_LOGDIR. Puedes extraer el archivo xplane.pb almacenado en esta ubicación y usar tensorboard para ver los perfiles en tu navegador con las instrucciones de TensorBoard. Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que tiene sugerencias para depurar, generar perfiles y optimizar tu modelo.

Entrenamiento de DLRM DCN v2 en v6e

En este instructivo, se muestra cómo entrenar el modelo DLRM DCN v2 en TPU v6e. Debes aprovisionar una TPU v6e con 64, 128 o 256 chips.

Si ejecutas en varios hosts, ejecuta el siguiente comando para restablecer tpu-runtime con la versión correcta de TensorFlow: Si ejecutas la app en un solo host, no necesitas ejecutar los siguientes dos comandos.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Establece una conexión SSH a worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

Establece el nombre de la TPU

export TPU_NAME=${TPU_NAME}

Ejecuta DLRM v2

pip install --user setuptools==65.5.0

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Ejecuta script.sh:

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Las siguientes marcas son necesarias para ejecutar cargas de trabajo de recomendación (DCN de DLRM):

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Resultados de comparativas

En la siguiente sección, se incluyen los resultados de las comparativas de DLRM DCN v2 y MaxDiffusion en v6e.

DLRM DCN v2

La secuencia de comandos de entrenamiento de DLRM DCN v2 se ejecutó a diferentes escalas. Consulta las tasas de transferencia en la siguiente tabla.

v6e-64 v6e-128 v6e-256
Pasos de entrenamiento 7000 7000 7000
Tamaño del lote global 131072 262144 524288
Capacidad de procesamiento (ejemplos/s) 2975334 5111808 10066329

MaxDiffusion

Ejecutamos la secuencia de comandos de entrenamiento para MaxDiffusion en un v6e-4, un v6e-16 y un 2xv6e-16. Consulta las tasas de transferencia en la siguiente tabla.

v6e-4 v6e-16 Dos v6e-16
Pasos de entrenamiento 0.069 0.073 0.13
Tamaño del lote global 8 32 64
Capacidad de procesamiento (ejemplos/s) 115.9 438.4 492.3

Programación de la recopilación

Trillium (v6e) incluye una nueva función llamada "programación de colecciones". Esta función ofrece una forma de administrar varias porciones de TPU que ejecutan una carga de trabajo de inferencia de host único en GKE y la API de Cloud TPU. Agrupar estas porciones en una colección facilita el ajuste de la cantidad de réplicas para que coincidan con la demanda. Las actualizaciones de software se controlan con cuidado para garantizar que una parte de las divisiones dentro de la colección siempre esté disponible para controlar el tráfico entrante.

Consulta la documentación de GKE para obtener más información sobre el uso de la programación de colecciones con GKE.

La función de programación de colecciones solo se aplica a la versión 6e.

Usa la programación de colecciones desde la API de Cloud TPU

Una colección de un solo host en la API de Cloud TPU es un recurso en cola en el que se establece una marca especial (--workload-type = availability-optimized) para indicarle a la infraestructura subyacente que se debe usar para entregar cargas de trabajo.

El siguiente comando aprovisiona una colección de host único con la API de Cloud TPU:

gcloud alpha compute tpus queued-resources create my-collection \
   --project=$PROJECT_ID \
   --zone=${ZONE} \
   --accelerator-type $ACCELERATOR_TYPE \
   --node-count ${NODE_COUNT} \
   --workload-type=availability-optimized

Supervisa y crea perfiles

Cloud TPU v6e admite la supervisión y la generación de perfiles con los mismos métodos que las generaciones anteriores de Cloud TPU. Para obtener más información sobre la supervisión, consulta Supervisa VMs de TPU.