Checkpoint automatico di Cloud TPU [anteprima pubblica]

Panoramica

Storicamente, quando una VM TPU richiede la manutenzione, la procedura viene avviata immediatamente, senza lasciare tempo agli utenti per eseguire azioni che garantiscono l'avanzamento, come il salvataggio di un checkpoint. Ciò è mostrato nella Figura 1(a).

checkpoint automatico

Fig. 1. Illustrazione della funzionalità Autocheckpoint: (a) Senza Autocheckpoint, l'avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando si verifica un evento di manutenzione imminente. (b) Con il checkpoint automatico, l'avanzamento dell'addestramento dall'ultimo checkpoint può essere conservato in caso di un evento di manutenzione imminente.

Puoi utilizzare il punto di controllo automatico (Figura 1(b)) per preservare l'avanzamento dell'addestramento configurando il codice per salvare un checkpoint non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, l'avanzamento dall'ultimo checkpoint viene salvato automaticamente. Questa funzionalità può essere usata sia sia su una sezione singola che su più sezioni.

La funzionalità Autocheckpoint funziona con framework in grado di acquisire SIGTERM e successivamente salvare un checkpoint. I framework supportati includono MaxText, Pax e JAX con Orbax. Il supporto per framework aggiuntivi verrà annunciato non appena saranno disponibili.

Al momento solo le TPU (v2-v4 e v5e) create tramite l'API Cloud TPU possono utilizzare questa funzionalità. Il supporto per le TPU in GKE verrà annunciato quando sarà disponibile.

Utilizzo di Autocheckpoint

La funzionalità checkpoint automatico è disattivata per impostazione predefinita. Quando crei una TPU o una risorsa in coda, puoi abilitarla aggiungendo il flag --autocheckpoint-enabled durante il provisioning della TPU. Con la funzionalità abilitata, Cloud TPU esegue i seguenti passaggi quando riceve la notifica di un evento di manutenzione:

  1. Acquisisci SIGTERM inviato al processo utilizzando il dispositivo TPU,
  2. Attende fino all'uscita del processo o dopo che sono trascorsi 5 minuti, a seconda dell'evento che si verifica per primo, ed esegue la manutenzione delle sezioni interessate.

Tieni presente che l'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework ML. Qualsiasi framework ML può supportare Autocheckpoint a condizione di poter acquisire il segnale SIGTERM e avviare un processo di checkpointing.

Nel codice dell'applicazione, devi abilitare le funzionalità Autocheckpoint fornite dal framework ML. In Pax, ad esempio, ciò significa abilitare i flag della riga di comando all'avvio dell'addestramento (consulta la guida rapida di Autocheckpoint con Pax). Dietro le quinte, i framework salvano un checkpoint non pianificato alla ricezione di un SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.

Guida rapida: checkpoint automatico con MaxText

MaxText è un "LLM open source, ad alte prestazioni, scalabile e ben collaudato, scritto in Python/JAX puro per il targeting di Cloud TPU". MaxText contiene tutte le configurazioni necessarie per utilizzare la funzionalità Autocheckpoint.

Il file README di MaxText descrive due modi per eseguire MaxText su larga scala:

Quando utilizzi multihost_runner.py, l'unica modifica richiesta è impostare il flag autocheckpoint-enabled durante il provisioning della risorsa in coda. Quando utilizzi multihost_job.py, l'unica modifica necessaria è specificare il flag della riga di comando ENABLE_AUTOCHECKPOINT=true all'avvio del job.

Guida rapida: checkpoint automatico con Pax su sezioni singole

In questa sezione, forniamo un esempio di come configurare e utilizzare Autocheckpoint con Pax su una singola sezione. Con la configurazione appropriata:

  • Quando si verifica un evento di manutenzione, verrà salvato un checkpoint.
  • Cloud TPU eseguirà la manutenzione delle VM TPU interessate dopo il salvataggio del checkpoint.
  • Quando Cloud TPU completa la manutenzione, puoi utilizzare la VM TPU come di consueto.
  1. Utilizza il flag autocheckpoint-enabled durante la creazione della VM TPU o della risorsa in coda.

    Ad esempio:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Installa Pax su una singola sezione

    La funzionalità Autocheckpoint funziona su versioni di Pax successive alla 1.1.0. Sulle VM TPU, installa jax[tpu] e la versione più recente di paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Avvia la formazione con la configurazione appropriata

    L'esempio seguente mostra come configurare il modello LmCloudSpmd2B per salvare i checkpoint attivati da Autocheckpoint in un bucket Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Nota i due flag che vengono passati al comando:

    • jax_fully_async_checkpoint: con questo flag attivo, verrà utilizzato orbax.checkpoint.AsyncCheckpointer. La classe AsyncCheckpointer salva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.
    • exit_after_ondemand_checkpoint: se questo flag è attivo, la TPU elabora le uscite dopo il salvataggio corretto del checkpoint automatico, che attiva l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà il verificarsi di un timeout (5 minuti) prima di eseguire la manutenzione richiesta.

Guida rapida: checkpoint automatico con Pax su Multislice

Il punto di controllo automatico funziona non solo per le sezioni singole, ma anche per Multisezione. Questa sezione descrive i passaggi necessari per utilizzare Autocheckpoint con Multislice.

  1. Specifica il punto di controllo automatico durante la creazione delle risorse in coda.

    È possibile eseguire il provisioning di un ambiente multislice solo tramite una richiesta di risorse in coda. Come per il caso a sezione singola, utilizza il flag autocheckpoint-enabled nella chiamata per creare una risorsa in coda.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Per i dettagli su tutte le opzioni disponibili, consulta la guida dell'utente di Multislice. Dopo aver creato la richiesta di risorse in coda e nello stato ACTIVE, segui i passaggi successivi per eseguire Pax con Autocheckpoint.

  2. Installare Pax su tutte le VM nell'ambiente Multislice.

    Sulle VM TPU, installa jax[tpu] e la versione più recente di paxml su tutte le VM TPU nel tuo ambiente multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Avvia la formazione con la configurazione appropriata

    Questo esempio mostra come configurare il modello LmCloudSpmd2B per il checkpoint automatico durante l'addestramento in un ambiente multisezione. Prima di eseguire lo script di addestramento, imposta DCN_MESH_SHAPE su [2, 1, 1] come mostrato nel codice seguente:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Quando avvii l'addestramento, oltre ai flag della riga di comando discussi nel caso single-slice, sono richiesti altri tre:

    • num_hosts: il numero totale di host. In questo caso, è 2.
    • host_index: l'indice dell'host che avvia l'addestramento. Varia da 0 a N-1, dove N è il numero totale di host.
    • server_addr: l'indirizzo IP del worker 0 del nodo 0, con una porta inutilizzata (ad esempio, 8476). Per trovare queste informazioni, utilizza hostname -i sul worker 0 del nodo 0.

Checkpoint automatico con Orbax

La funzionalità Checkpoint automatico non è limitata a MaxText o Pax. Qualsiasi framework in grado di acquisire il segnale SIGTERM e avviare un processo di checkpoint funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti JAX, offre queste funzionalità.

Come spiegato nella documentazione di Orbax, queste funzionalità sono attive per impostazione predefinita per gli utenti di orbax.checkpoint.CheckpointManager. Il metodo save richiamato dopo ogni passaggio controlla automaticamente se un evento di manutenzione è imminente e, in questo caso, salva un checkpoint anche se il numero di passaggio non è un multiplo di save_interval_steps. La documentazione di GitHub illustra anche come uscire dall'addestramento dopo aver salvato un checkpoint automatico, con una modifica al codice utente.