Cómo conservar el progreso del entrenamiento con Autocheckpoint

Históricamente, cuando una VM de TPU requiere mantenimiento, el procedimiento se inicia de inmediato, sin dejar tiempo para que los usuarios realicen acciones que preserven el progreso, como guardar un punto de control. Esto se muestra en la Figura 1(a).

Diagrama que muestra el impacto del mantenimiento del host con y sin puntos de control automáticos

Figura 1. Ilustración de la función Autocheckpoint: (a) Sin Autocheckpoint, se pierde el progreso del entrenamiento del último punto de control cuando hay un evento de mantenimiento próximo. (b) Con Autocheckpoint, el progreso del entrenamiento desde el último punto de control se puede conservar cuando hay un evento de mantenimiento próximo.

Puedes usar el punto de control automático (Figura 1(b)) para preservar el progreso del entrenamiento. Para ello, configura tu código para que guarde un punto de control no programado cuando se produzca un evento de mantenimiento. Cuando se produce un evento de mantenimiento, el progreso desde el último punto de control se guarda automáticamente. La función funciona en porciones únicas y en Multislice.

La función Autocheckpoint funciona con frameworks que pueden capturar señales SIGTERM y, luego, guardar un punto de control. Los frameworks compatibles son los siguientes:

Cómo usar Autocheckpoint

La función de punto de control automático está inhabilitada de forma predeterminada. Cuando creas una TPU o solicitas un recurso en cola, puedes habilitar el punto de control automático agregando la marca --autocheckpoint-enabled cuando aprovisiones la TPU. Con la función habilitada, Cloud TPU realiza los siguientes pasos una vez que recibe la notificación de un evento de mantenimiento:

  1. Captura la señal SIGTERM que se envía al proceso con el dispositivo TPU
  2. Espera hasta que se cierre el proceso o transcurran 5 minutos, lo que ocurra primero.
  3. Realiza el mantenimiento de las porciones afectadas

La infraestructura que usa Autocheckpoint es independiente del framework de AA. Cualquier framework de AA puede admitir Autocheckpoint si puede capturar la señal SIGTERM e iniciar un proceso de puntos de control.

En el código de la aplicación, debes habilitar las funciones de punto de control automático que proporciona el framework de AA. En Pax, por ejemplo, esto significa habilitar las marcas de línea de comandos cuando se inicia el entrenamiento. Para obtener más información, consulta la guía de inicio rápido de Autocheckpoint con Pax. En segundo plano, los frameworks guardan un punto de control no programado cuando se recibe una señal SIGTERM, y la VM de TPU afectada pasa por un mantenimiento cuando la TPU ya no está en uso.

Guía de inicio rápido: Punto de control automático con MaxText

MaxText es un LLM de alto rendimiento, escalable de forma arbitraria, de código abierto y bien probado, escrito en Python/JAXpuro orientado a Cloud TPU. MaxText contiene toda la configuración necesaria para usar la función Autocheckpoint.

En el archivo README de MaxText, se describen dos maneras de ejecutar MaxText a gran escala:

Cuando uses multihost_runner.py, habilita el punto de control automático configurando la marca autocheckpoint-enabled cuando aprovisiones el recurso en cola.

Cuando uses multihost_job.py, habilita el punto de control automático especificando la marca de línea de comandos ENABLE_AUTOCHECKPOINT=true cuando inicies el trabajo.

Guía de inicio rápido: Punto de control automático con Pax en una sola porción

En esta sección, se proporciona un ejemplo de cómo configurar y usar Autocheckpoint con Pax en una sola porción. Con la configuración adecuada, sucede lo siguiente:

  • Se guardará un punto de control cuando se produzca un evento de mantenimiento.
  • Cloud TPU realizará el mantenimiento de las VMs de TPU afectadas después de que se guarde el punto de control.
  • Cuando Cloud TPU complete el mantenimiento, podrás usar la VM de TPU como de costumbre.
  1. Usa la marca autocheckpoint-enabled cuando crees la VM de TPU o solicites un recurso en fila.

    Por ejemplo:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
        --accelerator-type $ACCELERATOR_TYPE \
        --version tpu-ubuntu2204-base \
        --autocheckpoint-enabled
  2. Instala Pax en una sola porción

    La función Autocheckpoint funciona en Pax 1.1.0 y versiones posteriores. En las VMs de TPU, instala jax[tpu] y la versión más reciente de paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Inicia el entrenamiento con la configuración adecuada.

    En el siguiente ejemplo, se muestra cómo configurar el modelo LmCloudSpmd2B para guardar los puntos de control activados por Autocheckpoint en un bucket de Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Ten en cuenta las dos marcas que se pasan al comando:

    • jax_fully_async_checkpoint: Con esta marca activada, se usará orbax.checkpoint.AsyncCheckpointer. La clase AsyncCheckpointer guarda automáticamente un punto de control cuando la secuencia de comandos de entrenamiento recibe un indicador SIGTERM.
    • exit_after_ondemand_checkpoint: Con esta marca activada, el proceso de TPU finaliza después de que el punto de control automático se guarda correctamente, lo que activa el mantenimiento para que se realice de inmediato. Si no usas esta marca, el entrenamiento continuará después de que se guarde el punto de control, y Cloud TPU esperará a que se produzca un tiempo de espera (5 minutos) antes de realizar el mantenimiento necesario.

Guía de inicio rápido: Punto de control automático con Pax en Multislice

El punto de control automático funciona no solo para una sola porción, sino también para Multislice. En esta sección, se detallan los pasos necesarios para usar Autocheckpoint con Multislice.

  1. Especifica el punto de control automático durante la creación de recursos en cola.

    Un entorno de Multislice solo se puede aprovisionar a través de una solicitud de recurso en fila. Al igual que en el caso de una sola porción, usa la marca autocheckpoint-enabled en la llamada para crear un recurso en fila.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
        --node-count $NODE_COUNT \
        --accelerator-type $ACCELERATOR_TYPE \
        --runtime-version tpu-ubuntu2204-base \
        --autocheckpoint-enabled

    Para obtener más información sobre todas las opciones disponibles, consulta la Guía del usuario de Multislice. Cuando se crea la solicitud de recursos en cola y está en el estado ACTIVE, sigue los pasos que se indican a continuación para ejecutar Pax con Autocheckpoint.

  2. Instala Pax en todas las VMs del entorno de Multislice.

    En las VMs de TPU, instala jax[tpu] y la versión más reciente de paxml en todas las VMs de TPU de tu entorno de Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Inicia el entrenamiento con la configuración adecuada.

    En este ejemplo, se muestra cómo configurar el modelo LmCloudSpmd2B para el punto de control automático cuando se entrena en un entorno de Multislice. Antes de ejecutar la secuencia de comandos de entrenamiento, establece DCN_MESH_SHAPE en [2, 1, 1], como se muestra en el siguiente ejemplo:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 4, 1]
        DCN_MESH_SHAPE = [2, 1, 1]

    Cuando se inicia el entrenamiento, además de las marcas de línea de comandos que se analizaron en el caso de una sola porción, se requieren tres más:

    • num_hosts: Es la cantidad total de hosts. En este caso, es 2.
    • host_index: Es el índice del host que inicia el entrenamiento. Varía de 0 a N-1, donde N es la cantidad total de hosts.
    • server_addr: Es la dirección IP del trabajador 0 del nodo 0, con un puerto que no se usa (por ejemplo, 8476). Para encontrar esta información, usa hostname -i en el trabajador 0 del nodo 0.

Punto de control automático con Orbax

La función de punto de control automático no se limita a MaxText o Pax. Cualquier framework que pueda capturar la señal SIGTERM e iniciar un proceso de creación de puntos de control funciona con la infraestructura que proporciona Autocheckpoint. Orbax, un espacio de nombres que proporciona bibliotecas de utilidad comunes para los usuarios de JAX, proporciona estas funciones.

Como se explica en la documentación de Orbax, estas funciones están habilitadas de forma predeterminada para los usuarios de orbax.checkpoint.CheckpointManager. El método save al que se llama después de cada paso verifica automáticamente si hay un evento de mantenimiento inminente y, de ser así, guarda un punto de control, incluso si el número de paso no es un múltiplo de save_interval_steps. En la documentación de GitHub, también se muestra cómo hacer que el entrenamiento finalice después de guardar un punto de control automático, con una modificación en el código del usuario.