Punto de control automático de Cloud TPU [versión preliminar pública]

Descripción general

Históricamente, cuando una VM de TPU requiere mantenimiento, el procedimiento se inicia de inmediato, sin dejar tiempo para que los usuarios realicen acciones de preservación del progreso, como guardar un punto de control. Esto se muestra en la Figura 1(a).

punto de control automático

Figura 1: Ilustración de la función Punto de control automático: (a) Sin el punto de control automático, el progreso del entrenamiento desde el último punto de control se pierde cuando se produce un próximo evento de mantenimiento. (b) Con el punto de control automático, el progreso del entrenamiento desde el último punto de control se puede conservar cuando hay un próximo evento de mantenimiento.

Puedes usar el punto de control automático (Figura 1(b)) a fin de conservar el progreso del entrenamiento. Para ello, configura tu código de modo que guarde un punto de control no programado cuando se produzca un evento de mantenimiento. Cuando ocurre un evento de mantenimiento, se guarda de forma automática el progreso desde el último punto de control. La característica funciona tanto en porciones individuales como en Multislice.

La función Punto de control automático funciona con frameworks que pueden capturar SIGTERM y, luego, guardar un punto de control. Los frameworks compatibles incluyen MaxText, Pax y JAX con Orbax. La compatibilidad con frameworks adicionales se anunciará a medida que estén disponibles.

Por el momento, solo las TPU (v2-v4 y v5e) creadas a través de la API de Cloud TPU pueden usar esta función. La compatibilidad con las TPU en GKE se anunciará cuando esté disponible.

Usa el punto de control automático

La funcionalidad de punto de control automático está inhabilitada de forma predeterminada. Cuando creas una TPU o un recurso en cola, puedes habilitarlo agregando la marca --autocheckpoint-enabled cuando aprovisionas la TPU. Con la función habilitada, Cloud TPU realiza los siguientes pasos una vez que recibe la notificación de un evento de mantenimiento:

  1. Capturar SIGTERM enviado al proceso con el dispositivo de TPU,
  2. Espera hasta que finalice el proceso o hayan transcurrido 5 minutos, lo que ocurra primero, y realiza el mantenimiento en las porciones afectadas.

Ten en cuenta que la infraestructura que usa Autocheckpoint es independiente del framework del AA. Cualquier framework de AA puede admitir un punto de control automático, siempre que pueda capturar la señal SIGTERM e iniciar un proceso de punto de control.

En el código de la aplicación, debes habilitar las capacidades del punto de control automático que proporciona el framework del AA. En Pax, por ejemplo, esto significa habilitar las marcas de línea de comandos cuando se inicia el entrenamiento (consulta la guía de inicio rápido del punto de control automático con Pax). En segundo plano, los frameworks guardan un punto de control no programado cuando se recibe un SIGTERM y la VM de TPU afectada pasa por mantenimiento cuando la TPU ya no está en uso.

Guía de inicio rápido: Punto de control automático con MaxText

MaxText es un “LLM de alto rendimiento, arbitrariamente escalable, de código abierto y bien probado que está escrito en Python/JAX puro y orientado a las Cloud TPU”. MaxText contiene toda la configuración necesaria para usar la función de punto de control automático.

El archivo README de MaxText describe dos formas de ejecutar MaxText a gran escala:

Cuando usas multihost_runner.py, el único cambio necesario es configurar la marca autocheckpoint-enabled cuando se aprovisiona el recurso en cola. Cuando se usa multihost_job.py, el único cambio requerido es especificar la marca de línea de comandos ENABLE_AUTOCHECKPOINT=true cuando se inicia el trabajo.

Guía de inicio rápido: Punto de control automático con Pax en porciones individuales

En esta sección, proporcionamos un ejemplo de cómo configurar y usar el punto de control automático con Pax en una sola porción. Con la configuración adecuada, haz lo siguiente:

  • Se guardará un punto de control cuando ocurra un evento de mantenimiento.
  • Cloud TPU realizará el mantenimiento en las VM de TPU afectadas después de que se guarde el punto de control.
  • Cuando Cloud TPU complete el mantenimiento, puedes usar la VM de TPU como de costumbre.
  1. Usa la marca autocheckpoint-enabled cuando crees la VM de TPU o el recurso en cola.

    Por ejemplo:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Instalar Pax en una sola porción

    La función Punto de control automático funciona en versiones de Pax >= 1.1.0. En las VM de TPU, instala jax[tpu] y la última versión de paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Iniciar el entrenamiento con la configuración adecuada

    En el siguiente ejemplo, se muestra cómo configurar el modelo LmCloudSpmd2B para guardar puntos de control activados por Autocheckpoint en un bucket de Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Observa las dos marcas que se pasan al comando:

    • jax_fully_async_checkpoint: Si esta marca está activada, se usará orbax.checkpoint.AsyncCheckpointer. La clase AsyncCheckpointer guarda automáticamente un punto de control cuando la secuencia de comandos de entrenamiento recibe una señal SIGTERM.
    • exit_after_ondemand_checkpoint: Con esta marca activada, los procesos de TPU salen después de que el punto de control automático se guarda de forma correcta, lo que activa el mantenimiento para que se realice de inmediato. Si no usas esta marca, el entrenamiento continuará después de que se guarde el punto de control y Cloud TPU esperará a que se agote el tiempo de espera (5 minutos) antes de realizar el mantenimiento requerido.

Guía de inicio rápido: Punto de control automático con Pax en Multislice

El punto de control automático funciona no solo para porciones individuales, sino también para Multislice. En esta sección, se detallan los pasos necesarios para usar el punto de control automático con Multislice.

  1. Especifica el punto de control automático durante la creación de recursos en cola.

    Un entorno de varias porciones solo se puede aprovisionar a través de una solicitud de recursos en cola. Al igual que con el caso de una sola sección, usa la marca autocheckpoint-enabled en la llamada para crear un recurso en cola.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Consulta la Guía del usuario de Multislice para obtener detalles sobre todas las opciones disponibles. Una vez que se crea la solicitud de recursos en cola y en el estado ACTIVE, sigue los siguientes pasos para ejecutar Pax con el punto de control automático.

  2. Instalar Pax en todas las VM del entorno Multislice.

    En las VM de TPU, instala jax[tpu] y la versión paxml más reciente en todas las VM de TPU del entorno Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Iniciar el entrenamiento con la configuración adecuada

    En este ejemplo, se muestra cómo configurar el modelo LmCloudSpmd2B para el punto de control automático cuando se entrena en un entorno Multislice. Antes de ejecutar la secuencia de comandos de entrenamiento, configura DCN_MESH_SHAPE en [2, 1, 1] como se muestra en el siguiente código:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Cuando se inicia el entrenamiento, además de las marcas de línea de comandos que se analizan en el caso de una sola sección, se requieren tres más:

    • num_hosts: Es la cantidad total de hosts. En este caso, es 2.
    • host_index: Es el índice del host que inicia el entrenamiento. Varía de 0 a N-1, en el que N es la cantidad total de hosts.
    • server_addr: Es la dirección IP del trabajador 0 del nodo 0, con un puerto sin usar (por ejemplo, 8476). Para encontrar esta información, usa hostname -i en el trabajador 0 del nodo 0.

Punto de control automático con Orbax

La función Punto de control automático no se limita a MaxText o Pax. Cualquier framework que pueda capturar la señal SIGTERM e iniciar un proceso de punto de control funciona con la infraestructura que proporciona Autocheckpoint. Orbax, un espacio de nombres que proporciona bibliotecas de utilidades comunes para usuarios de JAX, proporciona estas funciones.

Como se explica en la documentación de Orbax, estas funciones están habilitadas de forma predeterminada para los usuarios de orbax.checkpoint.CheckpointManager. El método save al que se llama después de cada paso verifica de forma automática si un evento de mantenimiento es inminente y, de ser así, guarda un punto de control incluso si el número de paso no es múltiplo de save_interval_steps. En la documentación de GitHub también se muestra cómo realizar la salida del entrenamiento después de guardar un punto de control automático, con una modificación en el código de usuario.