Introducción a Trillium (v6e)

En esta documentación, la API de TPU y los registros, se usa v6e para referirse a Trillium. v6e representa la 6ª generación de TPU de Google.

Con 256 chips por Pod, la v6e comparte muchas similitudes con la v5e. Este sistema está optimizado para ser el producto de mayor valor para el entrenamiento, la optimización y la publicación de transformadores, texto a imagen y redes neuronales convolucionales (CNN).

Arquitectura del sistema v6e

Para obtener información sobre la configuración de Cloud TPU, consulta la documentación de v6e.

En este documento, se enfoca en el proceso de configuración para el entrenamiento de modelos con los frameworks de JAX, PyTorch o TensorFlow. Con cada framework, puedes aprovisionar TPU con recursos en cola o Google Kubernetes Engine (GKE). La configuración de GKE se puede realizar con comandos de XPK o de GKE.

Prepara un proyecto de Google Cloud

  1. Accede a tu Cuenta de Google. Si aún no lo hiciste, regístrate para obtener una nueva cuenta.
  2. En la consola de Google Cloud, selecciona o crea un proyecto de Cloud en la página del selector de proyectos.
  3. Habilita la facturación para tu proyecto de Google Cloud. La facturación es obligatoria para todo el uso de Google Cloud.
  4. Instala los componentes de gcloud alpha.
  5. Ejecuta el siguiente comando para instalar la versión más reciente de los componentes de gcloud.

    gcloud components update
    
  6. Habilita la API de TPU con el siguiente comando gcloud en Cloud Shell. También puedes habilitarlo desde la consola de Google Cloud.

    gcloud services enable tpu.googleapis.com
    
  7. Habilita los permisos con la cuenta de servicio de TPU para la API de Compute Engine

    Las cuentas de servicio permiten que el servicio de Cloud TPU acceda a otros servicios de Google Cloud. Una cuenta de servicio administrada por el usuario es una práctica recomendada de Google Cloud. Sigue estas guías para crear y otorgar roles. Se requieren los siguientes roles:

    • Administrador de TPU
    • Administrador de almacenamiento
    • Escritor de registros
    • Escritor de métricas de Monitoring

    a. Configura los permisos de XPK con tu cuenta de usuario de GKE: XPK.

  8. Crea variables de entorno para el ID y la zona del proyecto.

     gcloud auth login
     gcloud config set project ${PROJECT_ID}
     gcloud config set compute/zone ${ZONE}
    
  9. Crea una identidad de servicio para la VM de TPU.

     gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
    

Capacidad segura

Comunícate con el equipo de asistencia de ventas o de cuentas de Cloud TPU para solicitar la cuota de TPU y responder cualquier pregunta sobre la capacidad.

Aprovisiona el entorno de Cloud TPU

Las TPU v6e se pueden aprovisionar y administrar con GKE, con GKE y XPK (una herramienta de wrapper de CLI sobre GKE) o como recursos en fila.

Requisitos previos

  • Verifica que tu proyecto tenga suficiente cuota de TPUS_PER_TPU_FAMILY, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto de Google Cloud.
  • La versión 6e se probó con la siguiente configuración:
    • Python 3.10 o versiones posteriores
    • Versiones de software nocturnas:
      • JAX nocturno 0.4.32.dev20240912
      • LibTPU nocturna 0.1.dev20240912+nightly
    • Versiones de software estables:
      • JAX + JAX Lib de la versión 0.4.35
  • Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
    • Cuota de VM de TPU
    • Cuota de direcciones IP
    • Quota de Hyperdisk-balance
  • Permisos del proyecto del usuario

Variables de entorno

En Cloud Shell, crea las siguientes variables de entorno:

export NODE_ID=TPU_NODE_ID # TPU name
export PROJECT_ID=PROJECT_ID
export ACCELERATOR_TYPE=v6e-16
export ZONE=us-central2-b
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT
export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID
export VALID_DURATION=VALID_DURATION

# Additional environment variable needed for Multislice:
export NUM_SLICES=NUM_SLICES

# Use a custom network for better performance as well as to avoid having the
# default network becoming overloaded.
export NETWORK_NAME=${PROJECT_ID}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

Descripciones de las marcas de comandos

Variable Descripción
NODE_ID El ID asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en fila.
ID DEL PROYECTO Nombre del proyecto de Google Cloud. Usa un proyecto existente o crea uno nuevo en
ZONA Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles.
ACCELERATOR_TYPE Consulta Tipos de aceleradores.
RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en la Google Cloud Console -> IAM -> Cuentas de servicio.

Por ejemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

NUM_SLICES Es la cantidad de rebanadas que se deben crear (solo es necesaria para Multislice).
QUEUED_RESOURCE_ID El ID de texto asignado por el usuario de la solicitud de recursos en cola.
VALID_DURATION Es la duración durante la cual la solicitud de recursos en cola es válida.
NETWORK_NAME Es el nombre de una red secundaria que se usará.
NETWORK_FW_NAME Es el nombre de un firewall de red secundario que se usará.

Optimizaciones del rendimiento de la red

Para obtener el mejor rendimiento, usa una red con 8,896 MTU (unidad de transmisión máxima).

De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que proporcionará un rendimiento de red poco óptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.

Para obtener más información sobre cómo cambiar la configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.

En el siguiente ejemplo, se crea una red con 8,896 MTU.

export RESOURCE_NAME=RESOURCE_NAME
export NETWORK_NAME=${RESOURCE_NAME}
export NETWORK_FW_NAME=${RESOURCE_NAME}
export PROJECT=X
gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional
gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \

Cómo usar varias NIC (opción para Multislice)

Las siguientes variables de entorno son necesarias para una subred secundaria cuando usas un entorno de Multislice.

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2

Usa los siguientes comandos para crear enrutamiento de IP personalizado para la red y la subred.

gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
   --bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
   --network="${NETWORK_NAME_2}" \
   --range=10.10.0.0/18 --region="${REGION}" \
   --project=$PROJECT

gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
   --network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
   --source-ranges 10.10.0.0/18 --project="${PROJECT}"

gcloud compute routers create "${ROUTER_NAME}" \
  --project="${PROJECT}" \
  --network="${NETWORK_NAME_2}" \
  --region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
  --router="${ROUTER_NAME}" \
  --region="${REGION}" \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project="${PROJECT}" \
  --enable-logging

Una vez que se crea una porción de varias redes, puedes validar que se usen ambas NIC ejecutando --command ifconfig como parte de la carga de trabajo de XPK. Luego, observa el resultado impreso de esa carga de trabajo de XPK en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.

python3 xpk.py workload create \
   --cluster ${CLUSTER_NAME} \
   (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \
   --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \
   --tpu-type=${ACCELERATOR_TYPE} \
   --num-slices=${NUM_SLICES}  \
   --on-demand \
   --zone $ZONE \
   --project $PROJECT_ID \
   [--enable-debug-logs] \
   [--use-vertex-tensorboard] \
   --command "ifconfig"

Verifica que tanto eth0 como eth1 tengan mtu=8,896. Una forma de verificar que tienes varios NIC en ejecución es ejecutar el comando --command "ifconfig" como parte de la carga de trabajo de XPK. Luego, observa el resultado impreso de esa carga de trabajo de xpk en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.

Configuración de TCP mejorada

En el caso de las TPU creadas con la interfaz de recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red cambiando la configuración predeterminada de TCP para rto_min y quickack.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
   --project "$PROJECT" --zone "${ZONE}" \
   --command='ip route show | while IFS= read -r route; do if ! echo $route | \
   grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \
   --worker=all

Aprovisionamiento con recursos en cola (API de Cloud TPU)

La capacidad se puede aprovisionar con el comando create de recursos en cola.

  1. Crea una solicitud de recurso en cola de TPU.

    La marca --reserved solo es necesaria para los recursos reservados, no para los recursos on demand.

    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
      --node-id ${TPU_NAME} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --accelerator-type ${ACCELERATOR_TYPE} \
      --runtime-version ${RUNTIME_VERSION} \
      --valid-until-duration ${VALID_DURATION} \
      --service-account ${SERVICE_ACCOUNT} \
      [--reserved]

    Si la solicitud de recursos en cola se crea correctamente, el estado en el campo "response" será "WAITING_FOR_RESOURCES" o "FAILED". Si la solicitud de recursos en cola está en el estado "WAITING_FOR_RESOURCES", significa que el recurso en cola se agregó a la cola y se aprovisionará cuando haya suficiente capacidad de TPU. Si la solicitud de recursos en cola está en el estado "FAILED", el motivo de la falla aparecerá en el resultado. La solicitud de recursos en cola vencerá si no se aprovisiona un v6e dentro de la duración especificada y el estado se convierte en "FAILED". Consulta la documentación pública de Recursos en cola para obtener más información.

    Cuando tu solicitud de recursos en cola esté en el estado "ACTIVO", podrás conectarte a tus VMs de TPU con SSH. Usa los comandos list o describe para consultar el estado de tu recurso en cola.

    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID}  \
       --project ${PROJECT_ID} --zone ${ZONE}
    

    Cuando el recurso en cola está en el estado "ACTIVE", el resultado es similar al siguiente:

      state:
       state: ACTIVE
    
  2. Administra tus VMs de TPU. Para conocer las opciones de administración de tus VMs de TPU, consulta Cómo administrar VMs de TPU.

  3. Conéctate a tus VMs de TPU con SSH

    Puedes instalar objetos binarios en cada VM de TPU de tu porción de TPU y ejecutar código. Consulta la sección Tipos de VM para determinar cuántas VMs tendrá tu fragmento.

    Para instalar los objetos binarios o ejecutar código, puedes usar SSH para conectarte a una VM con el comando tpu-vm ssh.

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
       --node=all # add this flag if you are using Multislice
    

    Para usar SSH y conectarte a una VM específica, usa la marca --worker que sigue a un índice basado en 0:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
    

    Si tienes formas de rebanada superiores a 8 chips, tendrás varias VM en una rebanada. En este caso, usa los parámetros --worker=all y --command en el comando gcloud alpha compute tpus tpu-vm ssh para ejecutar un comando en todas las VMs de forma simultánea. Por ejemplo:

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
      --zone  ${ZONE} --worker=all \
      --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    
  4. Borrar un recurso en cola

    Borra un recurso en cola al final de la sesión o quita las solicitudes de recursos en cola que estén en el estado "FAILED". Para borrar un recurso en cola, borra la porción y, luego, la solicitud de recurso en cola en 2 pasos:

    gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \
     --zone=${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
     --project ${PROJECT_ID} --zone ${ZONE} --quiet
    
    gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
    

Usa GKE con v6e

Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar TPU y entrenar o entregar modelos. Consulta Planifica el uso de TPU en GKE para aprender a usar GKE con TPU y v6e.

Configuración del framework

En esta sección, se describe el proceso de configuración general para el entrenamiento de modelos de AA con los frameworks JAX, PyTorch o TensorFlow. Puedes aprovisionar TPUs con recursos en cola o GKE. La configuración de GKE se puede realizar con XPK o comandos de Kubernetes.

Configura JAX con recursos en cola

Instala JAX en todas las VMs de TPU de tu porción o porciones de forma simultánea con gcloud alpha compute tpus tpu-vm ssh. Para Multislice, agrega --node=all.


gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
 --zone ${ZONE} --worker=all \
 --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'

Puedes ejecutar el siguiente código de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran aquí se produjeron con un fragmento v6e-16):

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
   --zone ${ZONE} --worker=all  \
   --command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'

El resultado es similar a este:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4

jax.device_count() muestra la cantidad total de chips en la porción determinada. jax.local_device_count() indica la cantidad de chips a los que puede acceder una sola VM en esta porción.

gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
   --project=${PROJECT_ID} --zone=${ZONE} --worker=all  \
   --command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
   cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
   && pip install -r requirements.txt  && pip install . '

Cómo solucionar problemas de configuración de JAX

Una sugerencia general es habilitar el registro detallado en el manifiesto de tu carga de trabajo de GKE. Luego, proporciona los registros al equipo de asistencia de GKE.

TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0

Mensajes de error

no endpoints available for service 'jobset-webhook-service'

Este error significa que el conjunto de trabajos no se instaló correctamente. Verifica si se están ejecutando los pods de Kubernetes de la implementación de jobset-controller-manager. Para obtener más información, consulta la documentación de solución de problemas de JobSet.

TPU initialization failed: Failed to connect

Asegúrate de que la versión de tu nodo de GKE sea 1.30.4-gke.1348000 o posterior (no se admite GKE 1.31).

Configuración para PyTorch

En esta sección, se describe cómo comenzar a usar PJRT en v6e con PyTorch/XLA. La versión recomendada de Python es 3.10.

Configura PyTorch con GKE y XPK

Puedes usar el siguiente contenedor de Docker con XPK, que ya tiene instaladas las dependencias de PyTorch:

us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028

Para crear una carga de trabajo de XPK, usa el siguiente comando:

python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES}  \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'

El uso de --base-docker-image crea una nueva imagen de Docker con el directorio de trabajo actual integrado en el nuevo Docker.

Configura PyTorch con recursos en cola

Sigue estos pasos para instalar PyTorch con recursos en cola y ejecutar una pequeña secuencia de comandos en v6e.

Instala dependencias con SSH para acceder a las VMs.

Para Multislice, agrega --node=all:

   gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='sudo apt install -y libopenblas-base pip3 \
    install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'

Mejora el rendimiento de los modelos con asignaciones grandes y frecuentes

En el caso de los modelos que tienen asignaciones frecuentes y de tamaño, observamos que el uso de tcmalloc mejora el rendimiento de manera significativa en comparación con la implementación predeterminada de malloc, por lo que el malloc predeterminado que se usa en la VM de TPU es tcmalloc. Sin embargo, según tu carga de trabajo (por ejemplo, con DLRM, que tiene asignaciones muy grandes para sus tablas de incorporación), tcmalloc puede causar una ralentización, en cuyo caso puedes intentar restablecer la siguiente variable con el malloc predeterminado:

unset LD_PRELOAD

Usa una secuencia de comandos de Python para realizar un cálculo en la VM v6e:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
   --project ${PROJECT_ID} \
   --zone ${ZONE} --worker all --command='
   unset LD_PRELOAD
   python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'

Esto genera un resultado similar al que se muestra a continuación:

SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
        [-1.4656,  0.3196, -2.8766],
        [ 0.8668, -1.5060,  0.7125]], device='xla:0')

Configuración para TensorFlow

Para la versión preliminar pública de v6e, solo se admite la versión del entorno de ejecución tf-nightly.

Para restablecer tpu-runtime con la versión compatible con TensorFlow v6e, ejecuta los siguientes comandos:

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
    --zone  ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
    --zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Usa SSH para acceder a worker-0:

$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
     --zone ${ZONE}

Instala TensorFlow en worker-0:

sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Exporta la variable de entorno TPU_NAME:

export TPU_NAME=v6e-16

Puedes ejecutar la siguiente secuencia de comandos de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran se generaron con un fragmento v6e-16):

import TensorFlow as tf
print("TensorFlow version " + tf.__version__)

@tf.function
  def add_fn(x,y):
  z = x + y
  return z

  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
  tf.config.experimental_connect_to_cluster(cluster_resolver)
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  strategy = tf.distribute.TPUStrategy(cluster_resolver)

  x = tf.constant(1.)
  y = tf.constant(1.)
  z = strategy.run(add_fn, args=(x,y))
  print(z)

El resultado es similar a este:

PerReplica:{
  0: tf.Tensor(2.0, shape=(), dtype=float32),
  1: tf.Tensor(2.0, shape=(), dtype=float32),
  2: tf.Tensor(2.0, shape=(), dtype=float32),
  3: tf.Tensor(2.0, shape=(), dtype=float32),
  4: tf.Tensor(2.0, shape=(), dtype=float32),
  5: tf.Tensor(2.0, shape=(), dtype=float32),
  6: tf.Tensor(2.0, shape=(), dtype=float32),
  7: tf.Tensor(2.0, shape=(), dtype=float32)
}

v6e con SkyPilot

Puedes usar TPU v6e con SkyPilot. Sigue estos pasos para agregar información de ubicación o precios relacionada con la v6e a SkyPilot.

  1. Agrega lo siguiente al final de ~/.sky/catalogs/v5/gcp/vms.csv :

    ,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0
    ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
    
  2. Especifica los siguientes recursos en un archivo YAML:

    # tpu_v6.yaml
    resources:
      accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use
      accelerator_args:
        runtime_version: v2-alpha-tpuv6e # Official suggested runtime
    
  3. Inicia un clúster con TPU v6e:

       sky launch tpu_v6.yaml -c tpu_v6
    
  4. Conéctate a la TPU v6e con SSH: ssh tpu_v6

Instructivos de inferencia

En las siguientes secciones, se proporcionan instructivos para la entrega de modelos de MaxText y PyTorch con JetStream, así como la entrega de modelos de MaxDiffusion en TPU v6e.

MaxText en JetStream

En este instructivo, se muestra cómo usar JetStream para entregar modelos de MaxText (JAX) en TPU v6e. JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU). En este instructivo, ejecutarás la comparativa de inferencia para el modelo Llama2-7B.

Antes de comenzar

  1. Crea una TPU v6e con 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Conéctate a la TPU con SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Ejecuta el instructivo

Para configurar JetStream y MaxText, convertir los puntos de control del modelo y ejecutar la comparativa de inferencia, sigue las instrucciones en el repositorio de GitHub.

Limpia

Borra la TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

vLLM en PyTorch TPU

A continuación, se incluye un instructivo sencillo que muestra cómo comenzar a usar vLLM en una VM de TPU. En nuestro ejemplo de prácticas recomendadas para implementar vLLM en Trillium en producción, publicaremos una guía del usuario de GKE en los próximos días (¡no te pierdas las novedades!).

Antes de comenzar

  1. Crea una TPU v6e con 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
       --node-id TPU_NAME \
       --project PROJECT_ID \
       --zone ZONE \
       --accelerator-type v6e-4 \
       --runtime-version v2-alpha-tpuv6e \
       --service-account SERVICE_ACCOUNT

    Descripciones de las marcas de comandos

    Variable Descripción
    NODE_ID El ID asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en fila.
    ID DEL PROYECTO Nombre del proyecto de Google Cloud. Usa un proyecto existente o crea uno nuevo en .
    ZONA Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles.
    ACCELERATOR_TYPE Consulta Tipos de aceleradores.
    RUNTIME_VERSION v2-alpha-tpuv6e
    SERVICE_ACCOUNT Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en la Google Cloud Console -> IAM -> Cuentas de servicio.

    Por ejemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com

  2. Conéctate a la TPU con SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME
    

Create a Conda environment

  1. (Recommended) Create a new conda environment for vLLM:

    conda create -n vllm python=3.10 -y
    conda activate vllm

Configura vLLM en TPU

  1. Clona el repositorio de vLLM y navega al directorio vLLM:

    git clone https://github.com/vllm-project/vllm.git && cd vllm
    
  2. Limpia los paquetes torch y torch-xla existentes:

    pip uninstall torch torch-xla -y
    
  3. Instala PyTorch y PyTorch XLA:

    pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
    pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
    
  4. Instala JAX y Pallas:

    pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    
    
  5. Instala otras dependencias de compilación:

    pip install -r requirements-tpu.txt
    VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
    

Obtén acceso al modelo

Debes firmar el acuerdo de consentimiento para usar la familia de modelos Llama3 en el repositorio de HuggingFace.

Genera un nuevo token de Hugging Face si aún no tienes uno:

  1. Haz clic en Tu perfil > Configuración > Tokens de acceso.
  2. Selecciona Token nuevo.
  3. Especifica el nombre que desees y un rol de al menos Read.
  4. Selecciona Generate un token.
  5. Copia el token generado en el portapapeles, configúralo como una variable de entorno y realiza la autenticación con huggingface-cli:

    export TOKEN=''
    git config --global credential.helper store
    huggingface-cli login --token $TOKEN

Descargar datos de comparativas

  1. Crea un directorio /data y descarga el conjunto de datos de ShareGPT desde Hugging Face.

    mkdir ~/data && cd ~/data
    wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    

Inicia el servidor vLLM

El siguiente comando descarga los pesos del modelo de Hugging Face Model Hub al directorio /tmp de la VM de TPU, precompila un rango de formas de entrada y escribe la compilación del modelo en ~/.cache/vllm/xla_cache.

Para obtener más detalles, consulta la documentación de vLLM.

   cd ~/vllm
   vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &

Ejecuta comparativas de vLLM

Ejecuta la secuencia de comandos de comparativas de vLLM:

   python benchmarks/benchmark_serving.py \
       --backend vllm \
       --model "meta-llama/Meta-Llama-3.1-8B"  \
       --dataset-name sharegpt \
       --dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json  \
       --num-prompts 1000

Limpia

Borra la TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

PyTorch en JetStream

En este instructivo, se muestra cómo usar JetStream para entregar modelos de PyTorch en TPU v6e. JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU). En este instructivo, ejecutarás la comparativa de inferencia para el modelo Llama2-7B.

Antes de comenzar

  1. Crea una TPU v6e con 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Conéctate a la TPU con SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Ejecuta el instructivo

Para configurar JetStream-PyTorch, convertir los puntos de control del modelo y ejecutar la comparativa de inferencia, sigue las instrucciones en el repositorio de GitHub.

Limpia

Borra la TPU:

   gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
      --project ${PROJECT_ID} \
      --zone ${ZONE} \
      --force \
      --async

Inferencia de MaxDiffusion

En este instructivo, se muestra cómo entregar modelos de MaxDiffusion en TPU v6e. En este instructivo, generarás imágenes con el modelo Stable Diffusion XL.

Antes de comenzar

  1. Crea una TPU v6e con 4 chips:

    gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \
        --node-id TPU_NAME \
        --project PROJECT_ID \
        --zone ZONE \
        --accelerator-type v6e-4 \
        --runtime-version v2-alpha-tpuv6e \
        --service-account SERVICE_ACCOUNT
  2. Conéctate a la TPU con SSH:

    gcloud compute tpus tpu-vm ssh TPU_NAME

Crea un entorno de Conda

  1. Crea un directorio para Miniconda:

    mkdir -p ~/miniconda3
  2. Descarga la secuencia de comandos del instalador de Miniconda:

    wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
  3. Instala Miniconda:

    bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
  4. Quita la secuencia de comandos del instalador de Miniconda:

    rm -rf ~/miniconda3/miniconda.sh
  5. Agrega Miniconda a tu variable PATH:

    export PATH="$HOME/miniconda3/bin:$PATH"
  6. Vuelve a cargar ~/.bashrc para aplicar los cambios a la variable PATH:

    source ~/.bashrc
  7. Crea un nuevo entorno de Conda:

    conda create -n tpu python=3.10
  8. Activa el entorno de Conda:

    source activate tpu

Configura MaxDiffusion

  1. Clona el repositorio de MaxDiffusion y navega al directorio MaxDiffusion:

    https://github.com/google/maxdiffusion.git && cd maxdiffusion
  2. Cambia a la rama mlperf-4.1:

    git checkout mlperf4.1
  3. Instala MaxDiffusion:

    pip install -e .
  4. Instala las dependencias:

    pip install -r requirements.txt
  5. Instala JAX:

    pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Generar imágenes

  1. Establece variables de entorno para configurar el entorno de ejecución de TPU:

    LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
  2. Genera imágenes con la instrucción y las configuraciones definidas en src/maxdiffusion/configs/base_xl.yml:

    python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"

Limpia

Borra la TPU:

gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \
    --project PROJECT_ID \
    --zone ZONE \
    --force \
    --async

Instructivos de capacitación

En las siguientes secciones, se proporcionan instructivos para entrenar MaxText.

Modelos de MaxDiffusion y PyTorch en TPU v6e

MaxText y MaxDiffusion

En las siguientes secciones, se describe el ciclo de vida de entrenamiento de los modelos MaxText y MaxDiffusion.

En general, los pasos de alto nivel son los siguientes:

  1. Compila la imagen base de la carga de trabajo.
  2. Ejecuta tu carga de trabajo con XPK.
    1. Compila el comando de entrenamiento para la carga de trabajo.
    2. Implementa la carga de trabajo.
  3. Sigue la carga de trabajo y consulta las métricas.
  4. Borra la carga de trabajo de XPK si no la necesitas.
  5. Borra el clúster de XPK cuando ya no lo necesites.

Compila la imagen base

Instala MaxText o MaxDiffusion y compila la imagen de Docker:

  1. Clona el repositorio que deseas usar y cambia al directorio del repositorio:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext
    

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
    
  2. Configura Docker para usar Google Cloud CLI:

    gcloud auth configure-docker
    
  3. Compila la imagen de Docker con el siguiente comando o con la pila estable de JAX. Para obtener más información sobre JAX Stable Stack, consulta Cómo compilar una imagen de Docker con JAX Stable Stack.

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
    
  4. Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen:

    bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
    
Compila una imagen de Docker con JAX Stable Stack

Puedes compilar las imágenes de Docker de MaxText y MaxDiffusion con la imagen base de la pila estable de JAX.

La pila estable de JAX proporciona un entorno coherente para MaxText y MaxDiffusion, ya que agrupa JAX con paquetes principales, como orbax, flax y optax, junto con una libtpu.so bien calificada que impulsa las utilidades del programa de TPU y otras herramientas esenciales. Estas bibliotecas se prueban para garantizar la compatibilidad, lo que proporciona una base estable para compilar y ejecutar MaxText y MaxDiffusion, y eliminar posibles conflictos debido a versiones de paquetes incompatibles.

La pila estable de JAX incluye una libtpu.so completamente publicada y calificada, la biblioteca principal que impulsa la compilación, la ejecución y la configuración de la red de ICI del programa de TPU. La versión de libtpu reemplaza la compilación nocturna que usaba JAX anteriormente y garantiza una funcionalidad coherente de los cálculos de XLA en TPU con pruebas de calificación a nivel de PJRT en IR de HLO/StableHLO.

Para compilar la imagen de Docker de MaxText y MaxDiffusion con la pila estable de JAX, cuando ejecutes la secuencia de comandos docker_build_dependency_image.sh, establece la variable MODE en stable_stack y la variable BASEIMAGE en la imagen base que deseas usar.

En el siguiente ejemplo, se especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 como la imagen base:

bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1

Para obtener una lista de las imágenes base de JAX Stable Stack disponibles, consulta Imágenes de JAX Stable Stack en Artifact Registry.

Ejecuta tu carga de trabajo con XPK

  1. Configura las siguientes variables de entorno si no usas los valores predeterminados que establece MaxText o MaxDiffusion:

    BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    PER_DEVICE_BATCH_SIZE=2
    NUM_STEPS=30
    MAX_TARGET_LENGTH=8192
  2. Compila la secuencia de comandos del modelo para que se copie como un comando de entrenamiento en el siguiente paso. Aún no ejecutes la secuencia de comandos del modelo.

    MaxText

    MaxText es un LLM de código abierto, altamente escalable y de alto rendimiento, escrito en Python y JAX puros, y se orienta a las TPU y las GPUs de Google Cloud para el entrenamiento y la inferencia.

    JAX_PLATFORMS=tpu,cpu \
    ENABLE_PJRT_COMPATIBILITY=true \
    TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
    TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \
            base_output_directory=$BASE_OUTPUT_DIR \
            dataset_type=synthetic \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            enable_checkpointing=false \
            gcs_metrics=true \
            profiler=xplane \
            skip_first_n_steps_for_profiler=5 \
            steps=${NUM_STEPS}"  # attention='dot_product'"
    

    Gemma2

    Gemma es una familia de modelos de lenguaje grandes (LLM) con pesos abiertos que desarrolló Google DeepMind, basada en la investigación y la tecnología de Gemini.

    # Requires v6e-256
    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=gemma2-27b \
        run_name=gemma2-27b-run \
        base_output_directory=${BASE_OUTPUT_DIR} \
        max_target_length=${MAX_TARGET_LENGTH} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        steps=${NUM_STEPS} \
        enable_checkpointing=false \
        use_iota_embed=true \
        gcs_metrics=true \
        dataset_type=synthetic \
        profiler=xplane \
        attention=flash
    

    Mixtral 8x7b

    Mixtral es un modelo de IA de vanguardia que desarrolló Mistral AI y que utiliza una arquitectura dispersa de mezcla de expertos (MoE).

    python3 MaxText/train.py MaxText/configs/base.yml \
        base_output_directory=${BASE_OUTPUT_DIR} \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
        model_name=mixtral-8x7b \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        tokenizer_path=assets/tokenizer.mistral-v1 \
        attention=flash \
        dtype=bfloat16 \
        dataset_type=synthetic \
        profiler=xplane
    

    Llama3-8b

    Llama es una familia de modelos de lenguaje grande (LLM) de ponderación abierta que desarrolló Meta.

    python3 MaxText/train.py MaxText/configs/base.yml \
        model_name=llama3-8b \
        base_output_directory=${BASE_OUTPUT_DIR} \
        dataset_type=synthetic \
        tokenizer_path=assets/tokenizer_llama3.tiktoken \
        per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \
        gcs_metrics=true \
        profiler=xplane \
        skip_first_n_steps_for_profiler=5 \
        steps=${NUM_STEPS} \
        max_target_length=${MAX_TARGET_LENGTH} \
        attention=flash"
    

    MaxDiffusion

    MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las TPU y las GPUs de Cloud. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.

    Debes instalar una rama específica para ejecutar MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git
    && cd maxdiffusion
    && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0
    && pip install -r requirements.txt
    && pip install .
    

    Secuencia de comandos de entrenamiento:

        cd maxdiffusion && OUT_DIR=${your_own_bucket}
        python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \
            run_name=v6e-sd2 \
            split_head_dim=True \
            attention=flash \
            train_new_unet=false \
            norm_num_groups=16 \
            output_dir=${BASE_OUTPUT_DIR} \
            per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
            [dcn_data_parallelism=2] \
            enable_profiler=True \
            skip_first_n_steps_for_profiler=95 \
            max_train_steps=${NUM_STEPS} ]
            write_metrics=True'
        
  3. Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca --base-docker-image para usar la imagen base de MaxText o especificar la marca --docker-image y la imagen que deseas usar.

    Opcional: Puedes habilitar el registro de depuración si incluyes la marca --enable-debug-logs. Para obtener más información, consulta Cómo depurar JAX en MaxText.

    Opcional: Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard si incluyes la marca --use-vertex-tensorboard. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.

    python3 xpk.py workload create \
        --cluster CLUSTER_NAME \
        {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \
        --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \
        --tpu-type=ACCELERATOR_TYPE \
        --num-slices=NUM_SLICES  \
        --on-demand \
        --zone $ZONE \
        --project $PROJECT_ID \
        [--enable-debug-logs] \
        [--use-vertex-tensorboard] \
        --command YOUR_MODEL_SCRIPT

    Reemplaza las siguientes variables:

    • CLUSTER_NAME: Es el nombre de tu clúster de XPK.
    • ACCELERATOR_TYPE: La versión y el tamaño de tu TPU. Por ejemplo, v6e-256
    • NUM_SLICES: Es la cantidad de porciones de TPU.
    • YOUR_MODEL_SCRIPT: Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento.

    El resultado incluye un vínculo para seguir tu carga de trabajo, similar al siguiente:

    [XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
    

    Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.

Cómo depurar JAX en MaxText

Usa los comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:

Supervisa JAX en MaxText con Vertex AI

Consulta los datos escalares y de perfil a través de TensorBoard administrado de Vertex AI.

  1. Aumenta las solicitudes de administración de recursos (CRUD) de la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.
  2. Instala dependencias como cloud-accelerator-diagnostics para Vertex AI:

    # xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI
    cd ~/xpk
    pip install .
  3. Crea tu clúster de XPK con la marca --create-vertex-tensorboard, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.

  4. Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca --use-vertex-tensorboard y la marca opcional --experiment-name. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.

Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

También puedes encontrar el vínculo de Vertex AI TensorBoard en la console de Google Cloud. Ve a Experimentos de Vertex AI en la consola de Google Cloud. Selecciona la región adecuada en el menú desplegable.

El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}.

Borra cargas de trabajo de XPK

Usa el comando xpk workload delete para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos que están bloqueados en la cola.

Borra el clúster de XPK

Usa el comando xpk cluster delete para borrar un clúster:

python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID

Llama y PyTorch

En este instructivo, se describe cómo entrenar modelos de Llama con PyTorch/XLA en TPU v6e con el conjunto de datos de WikiText. Además, los usuarios pueden acceder a las descripciones de modelos de TPU de PyTorch como imágenes de Docker aquí.

Instalación

Instala la bifurcación pytorch-tpu/transformers de los transformadores de Hugging Face y las dependencias en un entorno virtual:

git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate

Configura los parámetros de configuración del modelo

El comando de entrenamiento en la siguiente sección, Crea la secuencia de comandos de tu modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando entrenas con modelos más pequeños, puede ser suficiente con usar el paralelismo de datos y replicar los pesos en cada dispositivo. Consulta la Guía del usuario de SPMD de PyTorch/XLA para obtener más detalles sobre cómo dividir tensores en varios dispositivos en PyTorch/XLA.

  1. Crea el archivo de configuración del parámetro del modelo. La siguiente es la configuración del parámetro del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.

    {
        "architectures": [
            "LlamaForCausalLM"
        ],
        "attention_bias": false,
        "attention_dropout": 0.0,
        "bos_token_id": 128000,
        "eos_token_id": 128001,
        "hidden_act": "silu",
        "hidden_size": 4096,
        "initializer_range": 0.02,
        "intermediate_size": 14336,
        "max_position_embeddings": 8192,
        "model_type": "llama",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 8,
        "pretraining_tp": 1,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 500000.0,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.0.dev0",
        "use_cache": false,
        "vocab_size": 128256
    }
  2. Crea el archivo de configuración de FSDP:

    {
        "fsdp_transformer_layer_cls_to_wrap": [
            "LlamaDecoderLayer"
        ],
        "xla": true,
        "xla_fsdp_v2": true,
        "xla_fsdp_grad_ckpt": true
    }

    Consulta FSDPv2 para obtener más detalles sobre el FSDP.

  3. Sube los archivos de configuración a tus VMs de TPU con el siguiente comando:

        gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \
            --worker=all \
            --project=$PROJECT \
            --zone $ZONE

    También puedes crear los archivos de configuración en tu directorio de trabajo actual y usar la marca --base-docker-image en XPK.

Compila la secuencia de comandos de tu modelo

Compila la secuencia de comandos del modelo y especifica el archivo de configuración del parámetro del modelo con la marca --config_name y el archivo de configuración de FSDP con la marca --fsdp_config. Ejecutarás esta secuencia de comandos en tu TPU en la siguiente sección, Ejecuta el modelo. Aún no ejecutes la secuencia de comandos del modelo.

    PJRT_DEVICE=TPU
    XLA_USE_SPMD=1
    ENABLE_PJRT_COMPATIBILITY=true
    # Optional variables for debugging:
    XLA_IR_DEBUG=1
    XLA_HLO_DEBUG=1
    PROFILE_EPOCH=0
    PROFILE_STEP=3
    PROFILE_DURATION_MS=100000
    PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path
    python3 transformers/examples/pytorch/language-modeling/run_clm.py \
        --dataset_name wikitext \
        --dataset_config_name wikitext-2-raw-v1 \
        --per_device_train_batch_size 8 \
        --do_train \
        --output_dir /home/$USER/tmp/test-clm \
        --overwrite_output_dir \
        --config_name /home/$USER/config-8B.json \
        --cache_dir /home/$USER/cache \
        --tokenizer_name meta-llama/Meta-Llama-3-8B \
        --block_size 8192 \
        --optim adafactor \
        --save_strategy no \
        --logging_strategy no \
        --fsdp "full_shard" \
        --fsdp_config /home/$USER/fsdp_config.json \
        --torch_dtype bfloat16 \
        --dataloader_drop_last yes \
        --flash_attention \
        --max_steps 20

Ejecuta el modelo

Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior, Compila la secuencia de comandos de tu modelo.

Si usas una VM de TPU de host único (como v6e-4), puedes ejecutar el comando de entrenamiento directamente en la VM de TPU. Si usas una VM de TPU de varios hosts, usa el siguiente comando para ejecutar la secuencia de comandos de forma simultánea en todos los hosts:

gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \
    --zone $ZONE \
    --worker=all \
    --command=YOUR_COMMAND

Soluciona problemas de PyTorch/XLA

Si configuras las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación que especifique la variable PROFILE_LOGDIR. Puedes extraer el archivo xplane.pb almacenado en esta ubicación y usar tensorboard para ver los perfiles en tu navegador con las instrucciones de TensorBoard. Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que tiene sugerencias para depurar, generar perfiles y optimizar tus modelos.

Instructivo de DLRM DCN v2

En este instructivo, se muestra cómo entrenar el modelo DLRM DCN v2 en TPU v6e.

Si ejecutas en varios hosts, restablece tpu-runtime con la versión correcta de TensorFlow ejecutando el siguiente comando. Si ejecutas la prueba en un solo host, no necesitas ejecutar los siguientes dos comandos.

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID}
--zone  ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}  --project ${PROJECT_ID} \
 --zone  ${ZONE}   \
 --worker=all \
 --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'

Establece una conexión SSH a worker-0

gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}

Establece el nombre de la TPU

export TPU_NAME=${TPU_NAME}

Ejecuta DLRM v2

pip install cloud-tpu-client

pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps

pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git

export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'

TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py  --mode=train     --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'mixed_bfloat16'
task:
  use_synthetic_data: false
  use_tf_record_reader: true
  train_data:
    input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  validation_data:
    input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
    global_batch_size: 16384
    use_cached_data: true
  model:
    num_dense_features: 13
    bottom_mlp: [512, 256, 128]
    embedding_dim: 128
    interaction: 'multi_layer_dcn'
    dcn_num_layers: 3
    dcn_low_rank_dim: 512
    size_threshold: 8000
    top_mlp: [1024, 1024, 512, 256, 1]
    use_multi_hot: true
    concat_dense: false
    dcn_use_bias: true
    vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
    multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
    max_ids_per_chip_per_sample: 128
    max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
    max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
    use_partial_tpu_embedding: false
    size_threshold: 0
    initialize_tables_on_host: true
trainer:
  train_steps: 10000
  validation_interval: 1000
  validation_steps: 660
  summary_interval: 1000
  steps_per_loop: 1000
  checkpoint_interval: 0
  optimizer_config:
    embedding_optimizer: 'Adagrad'
    dense_optimizer: 'Adagrad'
    lr_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.025
      warmup_steps: 0
    dense_sgd_config:
      decay_exp: 2
      decay_start_steps: 70000
      decay_steps: 30000
      learning_rate: 0.00025
      warmup_steps: 8000
  train_tf_function: true
  train_tf_while_loop: true
  eval_tf_while_loop: true
  use_orbit: true
  pipeline_sparse_and_dense_execution: true"

Ejecuta script.sh:

chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force

Las siguientes marcas son necesarias para ejecutar cargas de trabajo de recomendación (DCN de DLRM):

ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"

Resultados de comparativas

En la siguiente sección, se incluyen los resultados de las comparativas de DLRM DCN v2 y MaxDiffusion en v6e.

DLRM DCN v2

La secuencia de comandos de entrenamiento de DLRM DCN v2 se ejecutó a diferentes escalas. Consulta las tasas de transferencia en la siguiente tabla.

v6e-64 v6e-128 v6e-256
Pasos de entrenamiento 7000 7000 7000
Tamaño del lote global 131072 262144 524288
Capacidad de procesamiento (ejemplos/s) 2975334 5111808 10066329

MaxDiffusion

Ejecutamos la secuencia de comandos de entrenamiento para MaxDiffusion en una v6e-4, una v6e-16 y una 2xv6e-16. Consulta las tasas de transferencia en la siguiente tabla.

v6e-4 v6e-16 Dos v6e-16
Pasos de entrenamiento 0.069 0.073 0.13
Tamaño del lote global 8 32 64
Capacidad de procesamiento (ejemplos/s) 115.9 438.4 492.3

Colecciones

La versión 6e presenta una nueva función llamada colecciones para beneficiar a los usuarios que ejecutan cargas de trabajo de publicación. La función de colecciones solo se aplica a la versión 6e.

Las colecciones te permiten indicarle a Google Cloud cuáles de tus nodos de TPU forman parte de una carga de trabajo de publicación. Esto permite que la infraestructura subyacente de Google Cloud limite y optimice las interrupciones que se pueden aplicar a las cargas de trabajo de entrenamiento en el curso normal de las operaciones.

Usa colecciones de la API de Cloud TPU

Una colección de un solo host en la API de Cloud TPU es un recurso en cola en el que se establece una marca especial (--workload-type = availability-optimized) para indicarle a la infraestructura subyacente que se debe usar para entregar cargas de trabajo.

El siguiente comando aprovisiona una colección de host único con la API de Cloud TPU:

gcloud alpha compute tpus queued-resources create COLLECTION_NAME \
   --project=project name \
   --zone=zone name \
   --accelerator-type=accelerator type \
   --node-count=number of nodes \
   --workload-type=availability-optimized

Supervisa y crea perfiles

Cloud TPU v6e admite la supervisión y la generación de perfiles con los mismos métodos que las generaciones anteriores de Cloud TPU. Para obtener más información sobre la supervisión, consulta Supervisa VMs de TPU.