Introducción a Trillium (v6e)
En esta documentación, la API de TPU y los registros, se usa v6e para referirse a Trillium. v6e representa la 6ª generación de TPU de Google.
Con 256 chips por Pod, la v6e comparte muchas similitudes con la v5e. Este sistema está optimizado para ser el producto de mayor valor para el entrenamiento, la optimización y la publicación de transformadores, texto a imagen y redes neuronales convolucionales (CNN).
Arquitectura del sistema v6e
Para obtener información sobre la configuración de Cloud TPU, consulta la documentación de v6e.
En este documento, se enfoca en el proceso de configuración para el entrenamiento de modelos con los frameworks de JAX, PyTorch o TensorFlow. Con cada framework, puedes aprovisionar TPU con recursos en cola o Google Kubernetes Engine (GKE). La configuración de GKE se puede realizar con comandos de XPK o de GKE.
Prepara un proyecto de Google Cloud
- Accede a tu Cuenta de Google. Si aún no lo hiciste, regístrate para obtener una nueva cuenta.
- En la consola de Google Cloud, selecciona o crea un proyecto de Cloud en la página del selector de proyectos.
- Habilita la facturación para tu proyecto de Google Cloud. La facturación es obligatoria para todo el uso de Google Cloud.
- Instala los componentes de gcloud alpha.
Ejecuta el siguiente comando para instalar la versión más reciente de los componentes de
gcloud
.gcloud components update
Habilita la API de TPU con el siguiente comando
gcloud
en Cloud Shell. También puedes habilitarlo desde la consola de Google Cloud.gcloud services enable tpu.googleapis.com
Habilita los permisos con la cuenta de servicio de TPU para la API de Compute Engine
Las cuentas de servicio permiten que el servicio de Cloud TPU acceda a otros servicios de Google Cloud. Una cuenta de servicio administrada por el usuario es una práctica recomendada de Google Cloud. Sigue estas guías para crear y otorgar roles. Se requieren los siguientes roles:
- Administrador de TPU
- Administrador de almacenamiento
- Escritor de registros
- Escritor de métricas de Monitoring
a. Configura los permisos de XPK con tu cuenta de usuario de GKE: XPK.
Crea variables de entorno para el ID y la zona del proyecto.
gcloud auth login gcloud config set project ${PROJECT_ID} gcloud config set compute/zone ${ZONE}
Crea una identidad de servicio para la VM de TPU.
gcloud alpha compute tpus tpu-vm service-identity create --zone=${ZONE}
Capacidad segura
Comunícate con el equipo de asistencia de ventas o de cuentas de Cloud TPU para solicitar la cuota de TPU y responder cualquier pregunta sobre la capacidad.
Aprovisiona el entorno de Cloud TPU
Las TPU v6e se pueden aprovisionar y administrar con GKE, con GKE y XPK (una herramienta de wrapper de CLI sobre GKE) o como recursos en fila.
Requisitos previos
- Verifica que tu proyecto tenga suficiente cuota de
TPUS_PER_TPU_FAMILY
, que especifica la cantidad máxima de chips a los que puedes acceder en tu proyecto de Google Cloud. - La versión 6e se probó con la siguiente configuración:
- Python
3.10
o versiones posteriores - Versiones de software nocturnas:
- JAX nocturno
0.4.32.dev20240912
- LibTPU nocturna
0.1.dev20240912+nightly
- JAX nocturno
- Versiones de software estables:
- JAX + JAX Lib de la versión 0.4.35
- Python
- Verifica que tu proyecto tenga suficiente cuota de TPU para lo siguiente:
- Cuota de VM de TPU
- Cuota de direcciones IP
- Quota de Hyperdisk-balance
- Permisos del proyecto del usuario
- Si usas GKE con XPK, consulta Permisos de la consola de Google Cloud en la cuenta de usuario o de servicio para conocer los permisos necesarios para ejecutar XPK.
Variables de entorno
En Cloud Shell, crea las siguientes variables de entorno:
export NODE_ID=TPU_NODE_ID # TPU name export PROJECT_ID=PROJECT_ID export ACCELERATOR_TYPE=v6e-16 export ZONE=us-central2-b export RUNTIME_VERSION=v2-alpha-tpuv6e export SERVICE_ACCOUNT=YOUR_SERVICE_ACCOUNT export QUEUED_RESOURCE_ID=QUEUED_RESOURCE_ID export VALID_DURATION=VALID_DURATION # Additional environment variable needed for Multislice: export NUM_SLICES=NUM_SLICES # Use a custom network for better performance as well as to avoid having the # default network becoming overloaded. export NETWORK_NAME=${PROJECT_ID}-mtu9k export NETWORK_FW_NAME=${NETWORK_NAME}-fw
Descripciones de las marcas de comandos
Variable | Descripción |
NODE_ID | El ID asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en fila. |
ID DEL PROYECTO | Nombre del proyecto de Google Cloud. Usa un proyecto existente o crea uno nuevo en |
ZONA | Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles. |
ACCELERATOR_TYPE | Consulta Tipos de aceleradores. |
RUNTIME_VERSION | v2-alpha-tpuv6e
|
SERVICE_ACCOUNT | Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en
la Google Cloud Console -> IAM -> Cuentas de servicio.
Por ejemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com |
NUM_SLICES | Es la cantidad de rebanadas que se deben crear (solo es necesaria para Multislice). |
QUEUED_RESOURCE_ID | El ID de texto asignado por el usuario de la solicitud de recursos en cola. |
VALID_DURATION | Es la duración durante la cual la solicitud de recursos en cola es válida. |
NETWORK_NAME | Es el nombre de una red secundaria que se usará. |
NETWORK_FW_NAME | Es el nombre de un firewall de red secundario que se usará. |
Optimizaciones del rendimiento de la red
Para obtener el mejor rendimiento, usa una red con 8,896 MTU (unidad de transmisión máxima).
De forma predeterminada, una nube privada virtual (VPC) solo proporciona una MTU de 1,460 bytes, lo que proporcionará un rendimiento de red poco óptimo. Puedes configurar la MTU de una red de VPC en cualquier valor entre 1,300 bytes y 8,896 bytes (inclusive). Los tamaños de MTU personalizados comunes son 1,500 bytes (Ethernet estándar) o 8,896 bytes (el máximo posible). Para obtener más información, consulta Tamaños válidos de MTU de la red de VPC.
Para obtener más información sobre cómo cambiar la configuración de MTU de una red existente o predeterminada, consulta Cambia la configuración de MTU de una red de VPC.
En el siguiente ejemplo, se crea una red con 8,896 MTU.
export RESOURCE_NAME=RESOURCE_NAME export NETWORK_NAME=${RESOURCE_NAME} export NETWORK_FW_NAME=${RESOURCE_NAME} export PROJECT=X gcloud compute networks create ${NETWORK_NAME} --mtu=8896 --project=${PROJECT} --subnet-mode=auto --bgp-routing-mode=regional gcloud compute firewall-rules create ${NETWORK_FW_NAME} --network ${NETWORK_NAME} \
Cómo usar varias NIC (opción para Multislice)
Las siguientes variables de entorno son necesarias para una subred secundaria cuando usas un entorno de Multislice.
export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=us-central2
Usa los siguientes comandos para crear enrutamiento de IP personalizado para la red y la subred.
gcloud compute networks create "${NETWORK_NAME_2}" --mtu=8896
--bgp-routing-mode=regional --subnet-mode=custom --project=$PROJECT
gcloud compute networks subnets create "${SUBNET_NAME_2}" \
--network="${NETWORK_NAME_2}" \
--range=10.10.0.0/18 --region="${REGION}" \
--project=$PROJECT
gcloud compute firewall-rules create "${FIREWALL_RULE_NAME}" \
--network "${NETWORK_NAME_2}" --allow tcp,icmp,udp \
--source-ranges 10.10.0.0/18 --project="${PROJECT}"
gcloud compute routers create "${ROUTER_NAME}" \
--project="${PROJECT}" \
--network="${NETWORK_NAME_2}" \
--region="${REGION}"
gcloud compute routers nats create "${NAT_CONFIG}" \
--router="${ROUTER_NAME}" \
--region="${REGION}" \
--auto-allocate-nat-external-ips \
--nat-all-subnet-ip-ranges \
--project="${PROJECT}" \
--enable-logging
Una vez que se crea una porción de varias redes, puedes validar que se usen ambas NIC ejecutando --command ifconfig
como parte de la carga de trabajo de XPK. Luego, observa el resultado impreso de esa carga de trabajo de XPK en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.
python3 xpk.py workload create \ --cluster ${CLUSTER_NAME} \ (--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}) \ --workload ${USER}-xpk-$ACCELERATOR_TYPE-$NUM_SLICES \ --tpu-type=${ACCELERATOR_TYPE} \ --num-slices=${NUM_SLICES} \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command "ifconfig"
Verifica que tanto eth0 como eth1 tengan mtu=8,896. Una forma de verificar que tienes varios NIC en ejecución es ejecutar el comando --command "ifconfig" como parte de la carga de trabajo de XPK. Luego, observa el resultado impreso de esa carga de trabajo de xpk en los registros de la consola de Cloud y verifica que tanto eth0 como eth1 tengan mtu=8896.
Configuración de TCP mejorada
En el caso de las TPU creadas con la interfaz de recursos en cola, puedes ejecutar el siguiente comando para mejorar el rendimiento de la red cambiando la configuración predeterminada de TCP para rto_min
y quickack
.
gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \ --project "$PROJECT" --zone "${ZONE}" \ --command='ip route show | while IFS= read -r route; do if ! echo $route | \ grep -q linkdown; then sudo ip route change ${route/lock/} rto_min 5ms quickack 1; fi; done' \ --worker=all
Aprovisionamiento con recursos en cola (API de Cloud TPU)
La capacidad se puede aprovisionar con el comando create
de recursos en cola.
Crea una solicitud de recurso en cola de TPU.
La marca
--reserved
solo es necesaria para los recursos reservados, no para los recursos on demand.gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \ --node-id ${TPU_NAME} \ --project ${PROJECT_ID} \ --zone ${ZONE} \ --accelerator-type ${ACCELERATOR_TYPE} \ --runtime-version ${RUNTIME_VERSION} \ --valid-until-duration ${VALID_DURATION} \ --service-account ${SERVICE_ACCOUNT} \ [--reserved]
Si la solicitud de recursos en cola se crea correctamente, el estado en el campo "response" será "WAITING_FOR_RESOURCES" o "FAILED". Si la solicitud de recursos en cola está en el estado "WAITING_FOR_RESOURCES", significa que el recurso en cola se agregó a la cola y se aprovisionará cuando haya suficiente capacidad de TPU. Si la solicitud de recursos en cola está en el estado "FAILED", el motivo de la falla aparecerá en el resultado. La solicitud de recursos en cola vencerá si no se aprovisiona un v6e dentro de la duración especificada y el estado se convierte en "FAILED". Consulta la documentación pública de Recursos en cola para obtener más información.
Cuando tu solicitud de recursos en cola esté en el estado "ACTIVO", podrás conectarte a tus VMs de TPU con SSH. Usa los comandos
list
odescribe
para consultar el estado de tu recurso en cola.gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE}
Cuando el recurso en cola está en el estado "ACTIVE", el resultado es similar al siguiente:
state: state: ACTIVE
Administra tus VMs de TPU. Para conocer las opciones de administración de tus VMs de TPU, consulta Cómo administrar VMs de TPU.
Conéctate a tus VMs de TPU con SSH
Puedes instalar objetos binarios en cada VM de TPU de tu porción de TPU y ejecutar código. Consulta la sección Tipos de VM para determinar cuántas VMs tendrá tu fragmento.
Para instalar los objetos binarios o ejecutar código, puedes usar SSH para conectarte a una VM con el comando
tpu-vm ssh
.gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ --node=all # add this flag if you are using Multislice
Para usar SSH y conectarte a una VM específica, usa la marca
--worker
que sigue a un índice basado en 0:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --worker=1
Si tienes formas de rebanada superiores a 8 chips, tendrás varias VM en una rebanada. En este caso, usa los parámetros
--worker=all
y--command
en el comandogcloud alpha compute tpus tpu-vm ssh
para ejecutar un comando en todas las VMs de forma simultánea. Por ejemplo:gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \ --zone ${ZONE} --worker=all \ --command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
Borrar un recurso en cola
Borra un recurso en cola al final de la sesión o quita las solicitudes de recursos en cola que estén en el estado "FAILED". Para borrar un recurso en cola, borra la porción y, luego, la solicitud de recurso en cola en 2 pasos:
gcloud alpha compute tpus tpu-vm delete $TPU_NAME --project=${PROJECT_ID} \ --zone=${ZONE} --quiet gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet
gcloud alpha compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \ --project ${PROJECT_ID} --zone ${ZONE} --quiet --force
Usa GKE con v6e
Si usas comandos de GKE con v6e, puedes usar comandos de Kubernetes o XPK para aprovisionar TPU y entrenar o entregar modelos. Consulta Planifica el uso de TPU en GKE para aprender a usar GKE con TPU y v6e.
Configuración del framework
En esta sección, se describe el proceso de configuración general para el entrenamiento de modelos de AA con los frameworks JAX, PyTorch o TensorFlow. Puedes aprovisionar TPUs con recursos en cola o GKE. La configuración de GKE se puede realizar con XPK o comandos de Kubernetes.
Configura JAX con recursos en cola
Instala JAX en todas las VMs de TPU de tu porción o porciones de forma simultánea con gcloud alpha compute tpus tpu-vm ssh
. Para Multislice, agrega --node=all
.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html</code>'
Puedes ejecutar el siguiente código de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran aquí se produjeron con un fragmento v6e-16):
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all \
--command='python3 -c "import jax; print(jax.device_count(), jax.local_device_count())"'
El resultado es similar a este:
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16 4
16 4
16 4
16 4
jax.device_count() muestra la cantidad total de chips en la porción determinada. jax.local_device_count() indica la cantidad de chips a los que puede acceder una sola VM en esta porción.
gcloud alpha compute tpus queued-resources ssh ${QUEUED_RESOURCE_ID} \
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
--command='git clone -b mlperf4.1 https://github.com/google/maxdiffusion.git &&
cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 &&
&& pip install -r requirements.txt && pip install . '
Cómo solucionar problemas de configuración de JAX
Una sugerencia general es habilitar el registro detallado en el manifiesto de tu carga de trabajo de GKE. Luego, proporciona los registros al equipo de asistencia de GKE.
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
Mensajes de error
no endpoints available for service 'jobset-webhook-service'
Este error significa que el conjunto de trabajos no se instaló correctamente. Verifica si se están ejecutando los pods de Kubernetes de la implementación de jobset-controller-manager. Para obtener más información, consulta la documentación de solución de problemas de JobSet.
TPU initialization failed: Failed to connect
Asegúrate de que la versión de tu nodo de GKE sea 1.30.4-gke.1348000 o posterior (no se admite GKE 1.31).
Configuración para PyTorch
En esta sección, se describe cómo comenzar a usar PJRT en v6e con PyTorch/XLA. La versión recomendada de Python es 3.10.
Configura PyTorch con GKE y XPK
Puedes usar el siguiente contenedor de Docker con XPK, que ya tiene instaladas las dependencias de PyTorch:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028
Para crear una carga de trabajo de XPK, usa el siguiente comando:
python3 xpk.py workload create \
--cluster ${CLUSTER_NAME} \
[--docker-image | --base-docker-image] us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20241028 \
--workload ${USER} -xpk-${ACCELERATOR_TYPE} -$NUM_SLICES \
--tpu-type=${ACCELERATOR_TYPE} \
--num-slices=${NUM_SLICES} \
--on-demand \
--zone ${ZONE} \
--project ${PROJECT_ID} \
--enable-debug-logs \
--command 'python3 -c "import torch; import torch_xla; import torch_xla.runtime as xr; print(xr.global_runtime_device_count())"'
El uso de --base-docker-image
crea una nueva imagen de Docker con el directorio de trabajo
actual integrado en el nuevo Docker.
Configura PyTorch con recursos en cola
Sigue estos pasos para instalar PyTorch con recursos en cola y ejecutar una pequeña secuencia de comandos en v6e.
Instala dependencias con SSH para acceder a las VMs.
Para Multislice, agrega --node=all
:
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
--project=${PROJECT_ID} \
--zone=${ZONE} \
--worker=all \
--command='sudo apt install -y libopenblas-base pip3 \
install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu \
--index-url https://download.pytorch.org/whl/nightly/cpu
pip install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html'
Mejora el rendimiento de los modelos con asignaciones grandes y frecuentes
En el caso de los modelos que tienen asignaciones frecuentes y de tamaño, observamos que el uso de tcmalloc
mejora el rendimiento de manera significativa en comparación con la implementación predeterminada de malloc
, por lo que el malloc
predeterminado que se usa en la VM de TPU es tcmalloc
. Sin embargo, según tu carga de trabajo (por ejemplo, con DLRM, que tiene asignaciones muy grandes para sus tablas de incorporación), tcmalloc
puede causar una ralentización, en cuyo caso puedes intentar restablecer la siguiente variable con el malloc
predeterminado:
unset LD_PRELOAD
Usa una secuencia de comandos de Python para realizar un cálculo en la VM v6e:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME}
--project ${PROJECT_ID} \
--zone ${ZONE} --worker all --command='
unset LD_PRELOAD
python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"
'
Esto genera un resultado similar al que se muestra a continuación:
SSH: Attempting to connect to worker 0...
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
xla:0
tensor([[ 0.3355, -1.4628, -3.2610],
[-1.4656, 0.3196, -2.8766],
[ 0.8668, -1.5060, 0.7125]], device='xla:0')
Configuración para TensorFlow
Para la versión preliminar pública de v6e, solo se admite la versión del entorno de ejecución tf-nightly.
Para restablecer tpu-runtime
con la versión compatible con TensorFlow v6e, ejecuta los siguientes comandos:
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} --worker=all --command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Usa SSH para acceder a worker-0:
$ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE}
Instala TensorFlow en worker-0:
sudo apt install -y libopenblas-base
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310
pip install cloud-tpu-client
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Exporta la variable de entorno TPU_NAME
:
export TPU_NAME=v6e-16
Puedes ejecutar la siguiente secuencia de comandos de Python para verificar cuántos núcleos de TPU están disponibles en tu fragmento y probar que todo esté instalado correctamente (los resultados que se muestran se generaron con un fragmento v6e-16):
import TensorFlow as tf
print("TensorFlow version " + tf.__version__)
@tf.function
def add_fn(x,y):
z = x + y
return z
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
x = tf.constant(1.)
y = tf.constant(1.)
z = strategy.run(add_fn, args=(x,y))
print(z)
El resultado es similar a este:
PerReplica:{
0: tf.Tensor(2.0, shape=(), dtype=float32),
1: tf.Tensor(2.0, shape=(), dtype=float32),
2: tf.Tensor(2.0, shape=(), dtype=float32),
3: tf.Tensor(2.0, shape=(), dtype=float32),
4: tf.Tensor(2.0, shape=(), dtype=float32),
5: tf.Tensor(2.0, shape=(), dtype=float32),
6: tf.Tensor(2.0, shape=(), dtype=float32),
7: tf.Tensor(2.0, shape=(), dtype=float32)
}
v6e con SkyPilot
Puedes usar TPU v6e con SkyPilot. Sigue estos pasos para agregar información de ubicación o precios relacionada con la v6e a SkyPilot.
Agrega lo siguiente al final de
~/.sky/catalogs/v5/gcp/vms.csv
:,,,tpu-v6e-1,1,tpu-v6e-1,us-south1,us-south1-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-1,1,tpu-v6e-1,us-east5,us-east5-b,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-south1,us-south1-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-4,1,tpu-v6e-4,us-east5,us-east5-b,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-south1,us-south1-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-8,1,tpu-v6e-8,us-east5,us-east5-b,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-south1,us-south1-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-16,1,tpu-v6e-16,us-east5,us-east5-b,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-south1,us-south1-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-32,1,tpu-v6e-32,us-east5,us-east5-b,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-south1,us-south1-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-64,1,tpu-v6e-64,us-east5,us-east5-b,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-south1,us-south1-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-128,1,tpu-v6e-128,us-east5,us-east5-b,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-south1,us-south1-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,europe-west4,europe-west4-a,0,0 ,,,tpu-v6e-256,1,tpu-v6e-256,us-east5,us-east5-b,0,0
Especifica los siguientes recursos en un archivo YAML:
# tpu_v6.yaml resources: accelerators: tpu-v6e-16 # Fill in the accelerator type you want to use accelerator_args: runtime_version: v2-alpha-tpuv6e # Official suggested runtime
Inicia un clúster con TPU v6e:
sky launch tpu_v6.yaml -c tpu_v6
Conéctate a la TPU v6e con SSH:
ssh tpu_v6
Instructivos de inferencia
En las siguientes secciones, se proporcionan instructivos para la entrega de modelos de MaxText y PyTorch con JetStream, así como la entrega de modelos de MaxDiffusion en TPU v6e.
MaxText en JetStream
En este instructivo, se muestra cómo usar JetStream para entregar modelos de MaxText (JAX) en TPU v6e. JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU). En este instructivo, ejecutarás la comparativa de inferencia para el modelo Llama2-7B.
Antes de comenzar
Crea una TPU v6e con 4 chips:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Conéctate a la TPU con SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Ejecuta el instructivo
Para configurar JetStream y MaxText, convertir los puntos de control del modelo y ejecutar la comparativa de inferencia, sigue las instrucciones en el repositorio de GitHub.
Limpia
Borra la TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
vLLM en PyTorch TPU
A continuación, se incluye un instructivo sencillo que muestra cómo comenzar a usar vLLM en una VM de TPU. En nuestro ejemplo de prácticas recomendadas para implementar vLLM en Trillium en producción, publicaremos una guía del usuario de GKE en los próximos días (¡no te pierdas las novedades!).
Antes de comenzar
Crea una TPU v6e con 4 chips:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Descripciones de las marcas de comandos
Variable Descripción NODE_ID El ID asignado por el usuario de la TPU que se crea cuando se asigna la solicitud de recurso en fila. ID DEL PROYECTO Nombre del proyecto de Google Cloud. Usa un proyecto existente o crea uno nuevo en . ZONA Consulta el documento Regiones y zonas de TPU para conocer las zonas compatibles. ACCELERATOR_TYPE Consulta Tipos de aceleradores. RUNTIME_VERSION v2-alpha-tpuv6e
SERVICE_ACCOUNT Esta es la dirección de correo electrónico de tu cuenta de servicio que puedes encontrar en la Google Cloud Console -> IAM -> Cuentas de servicio. Por ejemplo: tpu-service-account@<your_project_ID>.iam.gserviceaccount.com.com
Conéctate a la TPU con SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Create a Conda environment
(Recommended) Create a new conda environment for vLLM:
conda create -n vllm python=3.10 -y conda activate vllm
Configura vLLM en TPU
Clona el repositorio de vLLM y navega al directorio vLLM:
git clone https://github.com/vllm-project/vllm.git && cd vllm
Limpia los paquetes torch y torch-xla existentes:
pip uninstall torch torch-xla -y
Instala PyTorch y PyTorch XLA:
pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Instala JAX y Pallas:
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Instala otras dependencias de compilación:
pip install -r requirements-tpu.txt VLLM_TARGET_DEVICE="tpu" python setup.py develop sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Obtén acceso al modelo
Debes firmar el acuerdo de consentimiento para usar la familia de modelos Llama3 en el repositorio de HuggingFace.
Genera un nuevo token de Hugging Face si aún no tienes uno:
- Haz clic en Tu perfil > Configuración > Tokens de acceso.
- Selecciona Token nuevo.
- Especifica el nombre que desees y un rol de al menos
Read
. - Selecciona Generate un token.
Copia el token generado en el portapapeles, configúralo como una variable de entorno y realiza la autenticación con huggingface-cli:
export TOKEN='' git config --global credential.helper store huggingface-cli login --token $TOKEN
Descargar datos de comparativas
Crea un directorio /data y descarga el conjunto de datos de ShareGPT desde Hugging Face.
mkdir ~/data && cd ~/data wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
Inicia el servidor vLLM
El siguiente comando descarga los pesos del modelo de Hugging Face Model Hub al directorio /tmp de la VM de TPU, precompila un rango de formas de entrada y escribe la compilación del modelo en ~/.cache/vllm/xla_cache
.
Para obtener más detalles, consulta la documentación de vLLM.
cd ~/vllm
vllm serve "meta-llama/Meta-Llama-3.1-8B" --download_dir /tmp --num-scheduler-steps 4 --swap-space 16 --disable-log-requests --tensor_parallel_size=4 --max-model-len=2048 &> serve.log &
Ejecuta comparativas de vLLM
Ejecuta la secuencia de comandos de comparativas de vLLM:
python benchmarks/benchmark_serving.py \
--backend vllm \
--model "meta-llama/Meta-Llama-3.1-8B" \
--dataset-name sharegpt \
--dataset-path ~/data/ShareGPT_V3_unfiltered_cleaned_split.json \
--num-prompts 1000
Limpia
Borra la TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
PyTorch en JetStream
En este instructivo, se muestra cómo usar JetStream para entregar modelos de PyTorch en TPU v6e. JetStream es un motor con capacidad de procesamiento y memoria optimizada para la inferencia de modelos de lenguaje grandes (LLM) en dispositivos XLA (TPU). En este instructivo, ejecutarás la comparativa de inferencia para el modelo Llama2-7B.
Antes de comenzar
Crea una TPU v6e con 4 chips:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Conéctate a la TPU con SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Ejecuta el instructivo
Para configurar JetStream-PyTorch, convertir los puntos de control del modelo y ejecutar la comparativa de inferencia, sigue las instrucciones en el repositorio de GitHub.
Limpia
Borra la TPU:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
--project ${PROJECT_ID} \
--zone ${ZONE} \
--force \
--async
Inferencia de MaxDiffusion
En este instructivo, se muestra cómo entregar modelos de MaxDiffusion en TPU v6e. En este instructivo, generarás imágenes con el modelo Stable Diffusion XL.
Antes de comenzar
Crea una TPU v6e con 4 chips:
gcloud alpha compute tpus queued-resources create QUEUED_RESOURCE_ID \ --node-id TPU_NAME \ --project PROJECT_ID \ --zone ZONE \ --accelerator-type v6e-4 \ --runtime-version v2-alpha-tpuv6e \ --service-account SERVICE_ACCOUNT
Conéctate a la TPU con SSH:
gcloud compute tpus tpu-vm ssh TPU_NAME
Crea un entorno de Conda
Crea un directorio para Miniconda:
mkdir -p ~/miniconda3
Descarga la secuencia de comandos del instalador de Miniconda:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
Instala Miniconda:
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
Quita la secuencia de comandos del instalador de Miniconda:
rm -rf ~/miniconda3/miniconda.sh
Agrega Miniconda a tu variable
PATH
:export PATH="$HOME/miniconda3/bin:$PATH"
Vuelve a cargar
~/.bashrc
para aplicar los cambios a la variablePATH
:source ~/.bashrc
Crea un nuevo entorno de Conda:
conda create -n tpu python=3.10
Activa el entorno de Conda:
source activate tpu
Configura MaxDiffusion
Clona el repositorio de MaxDiffusion y navega al directorio MaxDiffusion:
https://github.com/google/maxdiffusion.git && cd maxdiffusion
Cambia a la rama
mlperf-4.1
:git checkout mlperf4.1
Instala MaxDiffusion:
pip install -e .
Instala las dependencias:
pip install -r requirements.txt
Instala JAX:
pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Generar imágenes
Establece variables de entorno para configurar el entorno de ejecución de TPU:
LIBTPU_INIT_ARGS="--xla_tpu_rwb_fusion=false --xla_tpu_dot_dot_fusion_duplicated=true --xla_tpu_scoped_vmem_limit_kib=65536"
Genera imágenes con la instrucción y las configuraciones definidas en
src/maxdiffusion/configs/base_xl.yml
:python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
Limpia
Borra la TPU:
gcloud compute tpus queued-resources delete QUEUED_RESOURCE_ID \ --project PROJECT_ID \ --zone ZONE \ --force \ --async
Instructivos de capacitación
En las siguientes secciones, se proporcionan instructivos para entrenar MaxText.
Modelos de MaxDiffusion y PyTorch en TPU v6e
MaxText y MaxDiffusion
En las siguientes secciones, se describe el ciclo de vida de entrenamiento de los modelos MaxText y MaxDiffusion.
En general, los pasos de alto nivel son los siguientes:
- Compila la imagen base de la carga de trabajo.
- Ejecuta tu carga de trabajo con XPK.
- Compila el comando de entrenamiento para la carga de trabajo.
- Implementa la carga de trabajo.
- Sigue la carga de trabajo y consulta las métricas.
- Borra la carga de trabajo de XPK si no la necesitas.
- Borra el clúster de XPK cuando ya no lo necesites.
Compila la imagen base
Instala MaxText o MaxDiffusion y compila la imagen de Docker:
Clona el repositorio que deseas usar y cambia al directorio del repositorio:
MaxText:
git clone https://github.com/google/maxtext.git && cd maxtext
MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion
Configura Docker para usar Google Cloud CLI:
gcloud auth configure-docker
Compila la imagen de Docker con el siguiente comando o con la pila estable de JAX. Para obtener más información sobre JAX Stable Stack, consulta Cómo compilar una imagen de Docker con JAX Stable Stack.
bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35
Si inicias la carga de trabajo desde una máquina que no tiene la imagen compilada de forma local, sube la imagen:
bash docker_upload_runner.sh CLOUD_IMAGE_NAME=${USER}_runner
Compila una imagen de Docker con JAX Stable Stack
Puedes compilar las imágenes de Docker de MaxText y MaxDiffusion con la imagen base de la pila estable de JAX.
La pila estable de JAX proporciona un entorno coherente para MaxText y MaxDiffusion, ya que agrupa JAX con paquetes principales, como orbax
, flax
y optax
, junto con una libtpu.so bien calificada que impulsa las utilidades del programa de TPU y otras herramientas esenciales. Estas bibliotecas se prueban para garantizar la compatibilidad, lo que proporciona una base estable para compilar y ejecutar MaxText y MaxDiffusion, y eliminar posibles conflictos debido a versiones de paquetes incompatibles.
La pila estable de JAX incluye una libtpu.so completamente publicada y calificada, la biblioteca principal que impulsa la compilación, la ejecución y la configuración de la red de ICI del programa de TPU. La versión de libtpu reemplaza la compilación nocturna que usaba JAX anteriormente y garantiza una funcionalidad coherente de los cálculos de XLA en TPU con pruebas de calificación a nivel de PJRT en IR de HLO/StableHLO.
Para compilar la imagen de Docker de MaxText y MaxDiffusion con la pila estable de JAX, cuando ejecutes la secuencia de comandos docker_build_dependency_image.sh
, establece la variable MODE
en stable_stack
y la variable BASEIMAGE
en la imagen base que deseas usar.
En el siguiente ejemplo, se especifica us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
como la imagen base:
bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
Para obtener una lista de las imágenes base de JAX Stable Stack disponibles, consulta Imágenes de JAX Stable Stack en Artifact Registry.
Ejecuta tu carga de trabajo con XPK
Configura las siguientes variables de entorno si no usas los valores predeterminados que establece MaxText o MaxDiffusion:
BASE_OUTPUT_DIR=gs://YOUR_BUCKET PER_DEVICE_BATCH_SIZE=2 NUM_STEPS=30 MAX_TARGET_LENGTH=8192
Compila la secuencia de comandos del modelo para que se copie como un comando de entrenamiento en el siguiente paso. Aún no ejecutes la secuencia de comandos del modelo.
MaxText
MaxText es un LLM de código abierto, altamente escalable y de alto rendimiento, escrito en Python y JAX puros, y se orienta a las TPU y las GPUs de Google Cloud para el entrenamiento y la inferencia.
JAX_PLATFORMS=tpu,cpu \ ENABLE_PJRT_COMPATIBILITY=true \ TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \ TPU_SLICE_BUILDER_DUMP_ICI=true && \ python /deps/MaxText/train.py /deps/MaxText/configs/base.yml \ base_output_directory=$BASE_OUTPUT_DIR \ dataset_type=synthetic \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ enable_checkpointing=false \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS}" # attention='dot_product'"
Gemma2
Gemma es una familia de modelos de lenguaje grandes (LLM) con pesos abiertos que desarrolló Google DeepMind, basada en la investigación y la tecnología de Gemini.
# Requires v6e-256 python3 MaxText/train.py MaxText/configs/base.yml \ model_name=gemma2-27b \ run_name=gemma2-27b-run \ base_output_directory=${BASE_OUTPUT_DIR} \ max_target_length=${MAX_TARGET_LENGTH} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ steps=${NUM_STEPS} \ enable_checkpointing=false \ use_iota_embed=true \ gcs_metrics=true \ dataset_type=synthetic \ profiler=xplane \ attention=flash
Mixtral 8x7b
Mixtral es un modelo de IA de vanguardia que desarrolló Mistral AI y que utiliza una arquitectura dispersa de mezcla de expertos (MoE).
python3 MaxText/train.py MaxText/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ model_name=mixtral-8x7b \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ tokenizer_path=assets/tokenizer.mistral-v1 \ attention=flash \ dtype=bfloat16 \ dataset_type=synthetic \ profiler=xplane
Llama3-8b
Llama es una familia de modelos de lenguaje grande (LLM) de ponderación abierta que desarrolló Meta.
python3 MaxText/train.py MaxText/configs/base.yml \ model_name=llama3-8b \ base_output_directory=${BASE_OUTPUT_DIR} \ dataset_type=synthetic \ tokenizer_path=assets/tokenizer_llama3.tiktoken \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} # set to 4 \ gcs_metrics=true \ profiler=xplane \ skip_first_n_steps_for_profiler=5 \ steps=${NUM_STEPS} \ max_target_length=${MAX_TARGET_LENGTH} \ attention=flash"
MaxDiffusion
MaxDiffusion es una colección de implementaciones de referencia de varios modelos de difusión latentes escritos en Python y JAX puros que se ejecutan en dispositivos XLA, incluidas las TPU y las GPUs de Cloud. Stable Diffusion es un modelo latente de texto a imagen que genera imágenes fotorrealistas a partir de cualquier entrada de texto.
Debes instalar una rama específica para ejecutar MaxDiffusion:
git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout e712c9fc4cca764b0930067b6e33daae2433abf0 && pip install -r requirements.txt && pip install .
Secuencia de comandos de entrenamiento:
cd maxdiffusion && OUT_DIR=${your_own_bucket} python -m src.maxdiffusion.models.train src/maxdiffusion/configs/base_2_base.yml \ run_name=v6e-sd2 \ split_head_dim=True \ attention=flash \ train_new_unet=false \ norm_num_groups=16 \ output_dir=${BASE_OUTPUT_DIR} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ [dcn_data_parallelism=2] \ enable_profiler=True \ skip_first_n_steps_for_profiler=95 \ max_train_steps=${NUM_STEPS} ] write_metrics=True'
Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior. Debes especificar la marca
--base-docker-image
para usar la imagen base de MaxText o especificar la marca--docker-image
y la imagen que deseas usar.Opcional: Puedes habilitar el registro de depuración si incluyes la marca
--enable-debug-logs
. Para obtener más información, consulta Cómo depurar JAX en MaxText.Opcional: Puedes crear un experimento de Vertex AI para subir datos a Vertex AI TensorBoard si incluyes la marca
--use-vertex-tensorboard
. Para obtener más información, consulta Supervisa JAX en MaxText con Vertex AI.python3 xpk.py workload create \ --cluster CLUSTER_NAME \ {--base-docker-image maxtext_base_image|--docker-image ${CLOUD_IMAGE_NAME}} \ --workload ${USER}-xpk-ACCELERATOR_TYPE-NUM_SLICES \ --tpu-type=ACCELERATOR_TYPE \ --num-slices=NUM_SLICES \ --on-demand \ --zone $ZONE \ --project $PROJECT_ID \ [--enable-debug-logs] \ [--use-vertex-tensorboard] \ --command YOUR_MODEL_SCRIPT
Reemplaza las siguientes variables:
- CLUSTER_NAME: Es el nombre de tu clúster de XPK.
- ACCELERATOR_TYPE: La versión y el tamaño de tu TPU. Por ejemplo,
v6e-256
- NUM_SLICES: Es la cantidad de porciones de TPU.
- YOUR_MODEL_SCRIPT: Es la secuencia de comandos del modelo que se ejecutará como un comando de entrenamiento.
El resultado incluye un vínculo para seguir tu carga de trabajo, similar al siguiente:
[XPK] Follow your workload here: https://console.cloud.google.com/kubernetes/service/zone/project_id/default/workload_name/details?project=project_id
Abre el vínculo y haz clic en la pestaña Registros para hacer un seguimiento de tu carga de trabajo en tiempo real.
Cómo depurar JAX en MaxText
Usa los comandos XPK complementarios para diagnosticar por qué no se ejecuta el clúster o la carga de trabajo:
- Lista de cargas de trabajo de XPK
- Inspector de XPK
- Habilita el registro detallado en tus registros de cargas de trabajo con la marca
--enable-debug-logs
cuando crees la carga de trabajo de XPK.
Supervisa JAX en MaxText con Vertex AI
Consulta los datos escalares y de perfil a través de TensorBoard administrado de Vertex AI.
- Aumenta las solicitudes de administración de recursos (CRUD) de la zona que usas de 600 a 5,000. Esto podría no ser un problema para cargas de trabajo pequeñas que usan menos de 16 VMs.
Instala dependencias como
cloud-accelerator-diagnostics
para Vertex AI:# xpk dependencies will install cloud-accelerator-diagnostics for Vertex AI cd ~/xpk pip install .
Crea tu clúster de XPK con la marca
--create-vertex-tensorboard
, como se documenta en Crea Vertex AI TensorBoard. También puedes ejecutar este comando en clústeres existentes.Crea tu experimento de Vertex AI cuando ejecutes tu carga de trabajo de XPK con la marca
--use-vertex-tensorboard
y la marca opcional--experiment-name
. Para obtener la lista completa de pasos, consulta Crea un experimento de Vertex AI para subir datos a Vertex AI TensorBoard.
Los registros incluyen un vínculo a un Vertex AI TensorBoard, similar al siguiente:
View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name
También puedes encontrar el vínculo de Vertex AI TensorBoard en la console de Google Cloud. Ve a Experimentos de Vertex AI en la consola de Google Cloud. Selecciona la región adecuada en el menú desplegable.
El directorio de TensorBoard también se escribe en el bucket de Cloud Storage que especificaste con ${BASE_OUTPUT_DIR}
.
Borra cargas de trabajo de XPK
Usa el comando xpk workload delete
para borrar una o más cargas de trabajo según el prefijo o el estado del trabajo. Este comando puede ser útil si enviaste cargas de trabajo de XPK que ya no necesitan ejecutarse o si tienes trabajos que están bloqueados en la cola.
Borra el clúster de XPK
Usa el comando xpk cluster delete
para borrar un clúster:
python3 xpk.py cluster delete --cluster CLUSTER_NAME --zone $ZONE --project $PROJECT_ID
Llama y PyTorch
En este instructivo, se describe cómo entrenar modelos de Llama con PyTorch/XLA en TPU v6e con el conjunto de datos de WikiText. Además, los usuarios pueden acceder a las descripciones de modelos de TPU de PyTorch como imágenes de Docker aquí.
Instalación
Instala la bifurcación pytorch-tpu/transformers
de los transformadores de Hugging Face y las dependencias en un entorno virtual:
git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git cd transformers pip3 install -e . pip3 install datasets pip3 install evaluate pip3 install scikit-learn pip3 install accelerate
Configura los parámetros de configuración del modelo
El comando de entrenamiento en la siguiente sección, Crea la secuencia de comandos de tu modelo, usa dos archivos de configuración JSON para definir los parámetros del modelo y la configuración de FSDP (paralelismo de datos completamente fragmentado). El fragmentación de FSDP se usa para que los pesos del modelo se ajusten a un tamaño de lote más grande durante el entrenamiento. Cuando entrenas con modelos más pequeños, puede ser suficiente con usar el paralelismo de datos y replicar los pesos en cada dispositivo. Consulta la Guía del usuario de SPMD de PyTorch/XLA para obtener más detalles sobre cómo dividir tensores en varios dispositivos en PyTorch/XLA.
Crea el archivo de configuración del parámetro del modelo. La siguiente es la configuración del parámetro del modelo para Llama3-8B. Para otros modelos, busca la configuración en Hugging Face. Por ejemplo, consulta la configuración de Llama2-7B.
{ "architectures": [ "LlamaForCausalLM" ], "attention_bias": false, "attention_dropout": 0.0, "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 8192, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "num_key_value_heads": 8, "pretraining_tp": 1, "rms_norm_eps": 1e-05, "rope_scaling": null, "rope_theta": 500000.0, "tie_word_embeddings": false, "torch_dtype": "bfloat16", "transformers_version": "4.40.0.dev0", "use_cache": false, "vocab_size": 128256 }
Crea el archivo de configuración de FSDP:
{ "fsdp_transformer_layer_cls_to_wrap": [ "LlamaDecoderLayer" ], "xla": true, "xla_fsdp_v2": true, "xla_fsdp_grad_ckpt": true }
Consulta FSDPv2 para obtener más detalles sobre el FSDP.
Sube los archivos de configuración a tus VMs de TPU con el siguiente comando:
gcloud alpha compute tpus tpu-vm scp YOUR_CONFIG_FILE.json $TPU_NAME:. \ --worker=all \ --project=$PROJECT \ --zone $ZONE
También puedes crear los archivos de configuración en tu directorio de trabajo actual y usar la marca
--base-docker-image
en XPK.
Compila la secuencia de comandos de tu modelo
Compila la secuencia de comandos del modelo y especifica el archivo de configuración del parámetro del modelo con la marca --config_name
y el archivo de configuración de FSDP con la marca --fsdp_config
.
Ejecutarás esta secuencia de comandos en tu TPU en la siguiente sección, Ejecuta el
modelo. Aún no ejecutes la secuencia de comandos del modelo.
PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true # Optional variables for debugging: XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 PROFILE_EPOCH=0 PROFILE_STEP=3 PROFILE_DURATION_MS=100000 PROFILE_LOGDIR=local VM path or gs://my-bucket/profile_path python3 transformers/examples/pytorch/language-modeling/run_clm.py \ --dataset_name wikitext \ --dataset_config_name wikitext-2-raw-v1 \ --per_device_train_batch_size 8 \ --do_train \ --output_dir /home/$USER/tmp/test-clm \ --overwrite_output_dir \ --config_name /home/$USER/config-8B.json \ --cache_dir /home/$USER/cache \ --tokenizer_name meta-llama/Meta-Llama-3-8B \ --block_size 8192 \ --optim adafactor \ --save_strategy no \ --logging_strategy no \ --fsdp "full_shard" \ --fsdp_config /home/$USER/fsdp_config.json \ --torch_dtype bfloat16 \ --dataloader_drop_last yes \ --flash_attention \ --max_steps 20
Ejecuta el modelo
Ejecuta el modelo con la secuencia de comandos que creaste en el paso anterior, Compila la secuencia de comandos de tu modelo.
Si usas una VM de TPU de host único (como v6e-4
), puedes ejecutar el comando de entrenamiento directamente en la VM de TPU. Si usas una VM de TPU de varios hosts, usa el siguiente comando para ejecutar la secuencia de comandos de forma simultánea en todos los hosts:
gcloud alpha compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT \ --zone $ZONE \ --worker=all \ --command=YOUR_COMMAND
Soluciona problemas de PyTorch/XLA
Si configuras las variables opcionales para la depuración en la sección anterior, el perfil del modelo se almacenará en la ubicación que especifique la variable PROFILE_LOGDIR
. Puedes extraer el archivo xplane.pb
almacenado en esta ubicación y usar tensorboard
para ver los perfiles en tu navegador con las instrucciones de TensorBoard. Si PyTorch/XLA no funciona como se espera, consulta la guía de solución de problemas, que tiene sugerencias para depurar, generar perfiles y optimizar tus modelos.
Instructivo de DLRM DCN v2
En este instructivo, se muestra cómo entrenar el modelo DLRM DCN v2 en TPU v6e.
Si ejecutas en varios hosts, restablece tpu-runtime
con la versión correcta de TensorFlow ejecutando el siguiente comando. Si ejecutas la prueba en un solo host, no necesitas ejecutar los siguientes dos comandos.
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID}
--zone ${ZONE} --worker=all \
--command="sudo sed -i 's/TF_DOCKER_URL=.*/TF_DOCKER_URL=gcr.io\/cloud-tpu-v2-images\/grpc_tpu_worker:v6e\"/' /etc/systemd/system/tpu-runtime.service"
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --project ${PROJECT_ID} \
--zone ${ZONE} \
--worker=all \
--command='sudo systemctl daemon-reload && sudo systemctl restart tpu-runtime'
Establece una conexión SSH a worker-0
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone ${ZONE} --project {$PROJECT_ID}
Establece el nombre de la TPU
export TPU_NAME=${TPU_NAME}
Ejecuta DLRM v2
pip install cloud-tpu-client
pip install gin-config && pip install tensorflow-datasets && pip install tf-keras-nightly --no-deps
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
git clone https://github.com/tensorflow/recommenders.git
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=~/recommenders/:~/models/
export TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true --tf_xla_sparse_core_disable_table_stacking=true --tf_mlir_enable_convert_control_to_data_outputs_pass=true --tf_mlir_enable_merge_control_flow_pass=true'
TF_USE_LEGACY_KERAS=1 TPU_LOAD_LIBRARY=0 python3 ./models/official/recommendation/ranking/train.py --mode=train --model_dir=gs://ptxla-debug/tf/sc/dlrm/runs/2/ --params_override="
runtime:
distribution_strategy: tpu
mixed_precision_dtype: 'mixed_bfloat16'
task:
use_synthetic_data: false
use_tf_record_reader: true
train_data:
input_path: 'gs://trillium-datasets/criteo/train/day_*/*'
global_batch_size: 16384
use_cached_data: true
validation_data:
input_path: 'gs://trillium-datasets/criteo/eval/day_*/*'
global_batch_size: 16384
use_cached_data: true
model:
num_dense_features: 13
bottom_mlp: [512, 256, 128]
embedding_dim: 128
interaction: 'multi_layer_dcn'
dcn_num_layers: 3
dcn_low_rank_dim: 512
size_threshold: 8000
top_mlp: [1024, 1024, 512, 256, 1]
use_multi_hot: true
concat_dense: false
dcn_use_bias: true
vocab_sizes: [40000000,39060,17295,7424,20265,3,7122,1543,63,40000000,3067956,405282,10,2209,11938,155,4,976,14,40000000,40000000,40000000,590152,12973,108,36]
multi_hot_sizes: [3,2,1,2,6,1,1,1,1,7,3,8,1,6,9,5,1,1,1,12,100,27,10,3,1,1]
max_ids_per_chip_per_sample: 128
max_ids_per_table: [280, 128, 64, 272, 432, 624, 64, 104, 368, 352, 288, 328, 304, 576, 336, 368, 312, 392, 408, 552, 2880, 1248, 720, 112, 320, 256]
max_unique_ids_per_table: [104, 56, 40, 32, 72, 32, 40, 32, 32, 144, 64, 192, 32, 40, 136, 32, 32, 32, 32, 240, 1352, 432, 120, 80, 32, 32]
use_partial_tpu_embedding: false
size_threshold: 0
initialize_tables_on_host: true
trainer:
train_steps: 10000
validation_interval: 1000
validation_steps: 660
summary_interval: 1000
steps_per_loop: 1000
checkpoint_interval: 0
optimizer_config:
embedding_optimizer: 'Adagrad'
dense_optimizer: 'Adagrad'
lr_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.025
warmup_steps: 0
dense_sgd_config:
decay_exp: 2
decay_start_steps: 70000
decay_steps: 30000
learning_rate: 0.00025
warmup_steps: 8000
train_tf_function: true
train_tf_while_loop: true
eval_tf_while_loop: true
use_orbit: true
pipeline_sparse_and_dense_execution: true"
Ejecuta script.sh
:
chmod +x script.sh
./script.sh
pip install https://storage.googleapis.com/tensorflow-public-build-artifacts/prod/tensorflow/official/release/nightly/linux_x86_tpu/wheel_py310/749/20240915-062017/github/tensorflow/build_output/tf_nightly_tpu-2.18.0.dev20240915-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl \
-f https://storage.googleapis.com/libtpu-tf-releases/index.html --force
Las siguientes marcas son necesarias para ejecutar cargas de trabajo de recomendación (DCN de DLRM):
ENV TF_XLA_FLAGS='--tf_mlir_enable_mlir_bridge=true \
--tf_mlir_enable_tpu_variable_runtime_reformatting_pass=false \
--tf_mlir_enable_convert_control_to_data_outputs_pass=true \
--tf_mlir_enable_merge_control_flow_pass=true --tf_xla_disable_full_embedding_pipelining=true' \
ENV LIBTPU_INIT_ARGS="--xla_sc_splitting_along_feature_dimension=auto \
--copy_with_dynamic_shape_op_output_pjrt_buffer=true"
Resultados de comparativas
En la siguiente sección, se incluyen los resultados de las comparativas de DLRM DCN v2 y MaxDiffusion en v6e.
DLRM DCN v2
La secuencia de comandos de entrenamiento de DLRM DCN v2 se ejecutó a diferentes escalas. Consulta las tasas de transferencia en la siguiente tabla.
v6e-64 | v6e-128 | v6e-256 | |
Pasos de entrenamiento | 7000 | 7000 | 7000 |
Tamaño del lote global | 131072 | 262144 | 524288 |
Capacidad de procesamiento (ejemplos/s) | 2975334 | 5111808 | 10066329 |
MaxDiffusion
Ejecutamos la secuencia de comandos de entrenamiento para MaxDiffusion en una v6e-4, una v6e-16 y una 2xv6e-16. Consulta las tasas de transferencia en la siguiente tabla.
v6e-4 | v6e-16 | Dos v6e-16 | |
Pasos de entrenamiento | 0.069 | 0.073 | 0.13 |
Tamaño del lote global | 8 | 32 | 64 |
Capacidad de procesamiento (ejemplos/s) | 115.9 | 438.4 | 492.3 |
Colecciones
La versión 6e presenta una nueva función llamada colecciones para beneficiar a los usuarios que ejecutan cargas de trabajo de publicación. La función de colecciones solo se aplica a la versión 6e.
Las colecciones te permiten indicarle a Google Cloud cuáles de tus nodos de TPU forman parte de una carga de trabajo de publicación. Esto permite que la infraestructura subyacente de Google Cloud limite y optimice las interrupciones que se pueden aplicar a las cargas de trabajo de entrenamiento en el curso normal de las operaciones.
Usa colecciones de la API de Cloud TPU
Una colección de un solo host en la API de Cloud TPU es un recurso en cola en el que se establece una marca especial (--workload-type = availability-optimized
) para indicarle a la infraestructura subyacente que se debe usar para entregar cargas de trabajo.
El siguiente comando aprovisiona una colección de host único con la API de Cloud TPU:
gcloud alpha compute tpus queued-resources create COLLECTION_NAME \ --project=project name \ --zone=zone name \ --accelerator-type=accelerator type \ --node-count=number of nodes \ --workload-type=availability-optimized
Supervisa y crea perfiles
Cloud TPU v6e admite la supervisión y la generación de perfiles con los mismos métodos que las generaciones anteriores de Cloud TPU. Para obtener más información sobre la supervisión, consulta Supervisa VMs de TPU.