Checkpoint automatico Cloud TPU [Anteprima pubblica]

Panoramica

In passato, quando una VM TPU richiede la manutenzione, la procedura viene avviata immediatamente, senza lasciare tempo agli utenti per eseguire azioni che garantiscono l'avanzamento come il salvataggio di un checkpoint. Questo è mostrato nella Figura 1(a).

checkpoint automatico

Fig. 1. Illustrazione della funzionalità Punto di controllo automatico: (a) Senza checkpoint automatico, l'avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando si verifica un evento di manutenzione imminente. (b) Con Autocheckpoint, l'avanzamento dell'addestramento dall'ultimo checkpoint può essere mantenuto quando si verifica un evento di manutenzione imminente.

Puoi utilizzare Autocheckpoint (Figura 1(b)) per preservare l'avanzamento dell'addestramento configurando il codice in modo da salvare un checkpoint non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, l'avanzamento dall'ultimo checkpoint viene salvato automaticamente. La funzionalità è compatibile sia con le sezioni singole sia con più sezioni.

La funzionalità Autocheckpoint funziona con framework in grado di acquisire SIGTERM e successivamente salvare un checkpoint. I framework supportati includono MaxText, Pax e JAX con Orbax. Il supporto di framework aggiuntivi verrà annunciato non appena saranno disponibili.

Per il momento solo le TPU (v2-v4 e v5e) create tramite l'API Cloud TPU possono utilizzare questa funzionalità. Il supporto per le TPU in GKE verrà annunciato quando sarà disponibile.

Utilizzo di Autocheckpoint

La funzionalità del checkpoint automatico è disattivata per impostazione predefinita. Quando crei una TPU o una risorsa in coda, puoi abilitarla aggiungendo il flag --autocheckpoint-enabled durante il provisioning della TPU. Con la funzionalità abilitata, Cloud TPU esegue i seguenti passaggi dopo aver ricevuto la notifica di un evento di manutenzione:

  1. Acquisisci SIGTERM inviato al processo utilizzando il dispositivo TPU,
  2. Attende la chiusura del processo o non sono trascorsi 5 minuti, a seconda dell'evento che si verifica per primo, ed esegue la manutenzione delle sezioni interessate.

Tieni presente che l'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework ML. Qualsiasi framework ML può supportare Autocheckpoint purché sia in grado di acquisire l'indicatore SIGTERM e avviare un processo di checkpoint.

Nel codice dell'applicazione, devi abilitare le funzionalità Autocheckpoint fornite dal framework ML. In Pax, ad esempio, questo significa abilitare i flag della riga di comando quando si avvia l'addestramento (consulta la guida rapida di autocheckpoint con Pax). Dietro le quinte, i framework salvano un checkpoint non pianificato quando viene ricevuto un SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.

Guida rapida: Autocheckpoint con MaxText

MaxText è un "LLM ad alte prestazioni, scalabile arbitrariamente, open source e ben collaudato, scritto in puro Python/JAX che ha come target le Cloud TPU". MaxText contiene tutta la configurazione necessaria per utilizzare la funzionalità Autocheckpoint.

Il file README di MaxText descrive due modi per eseguire MaxText su larga scala:

Quando utilizzi multihost_runner.py, l'unica modifica necessaria è l'impostazione del flag autocheckpoint-enabled durante il provisioning della risorsa in coda. Quando utilizzi multihost_job.py, l'unica modifica necessaria è specificare il flag della riga di comando ENABLE_AUTOCHECKPOINT=true all'avvio del job.

Guida rapida: Autocheckpoint con Pax su sezioni singole

In questa sezione, forniamo un esempio di come configurare e utilizzare Autocheckpoint con Pax su una singola sezione. Con la configurazione appropriata:

  • Verrà salvato un checkpoint quando si verifica un evento di manutenzione.
  • Cloud TPU eseguirà la manutenzione sulle VM TPU interessate dopo il salvataggio del checkpoint.
  • Quando Cloud TPU completa la manutenzione, puoi utilizzare la VM TPU come di consueto.
  1. Utilizza il flag autocheckpoint-enabled durante la creazione della VM TPU o della risorsa in coda.

    Ad esempio:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Installa Pax su una singola sezione

    La funzione Autocheckpoint funziona su versioni Pax >= 1.1.0. Sulle VM TPU, installa jax[tpu] e la versione più recente di paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Avvia l'addestramento con la configurazione appropriata

    L'esempio seguente mostra come configurare il modello LmCloudSpmd2B per salvare i checkpoint attivati da Autocheckpoint in un bucket Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Osserva i due flag passati al comando:

    • jax_fully_async_checkpoint: Quando questo flag sarà attivato, verrà utilizzato orbax.checkpoint.AsyncCheckpointer. La classe AsyncCheckpointer salva automaticamente un checkpoint quando lo script di addestramento riceve un indicatore SIGTERM.
    • exit_after_ondemand_checkpoint: Quando questo flag è attivato, il processo TPU si chiude dopo il salvataggio del checkpoint automatico, attivando così l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà il verificarsi di un timeout (5 minuti) prima di eseguire la manutenzione richiesta.

Guida rapida: Autocheckpoint con Pax su più sezioni

Il checkpoint automatico funziona non solo per le sezioni singole, ma anche per Multislice. Questa sezione dettagli i passaggi necessari per utilizzare Autocheckpoint con Multislice.

  1. Specifica il checkpoint automatico durante la creazione delle risorse in coda.

    Il provisioning di un ambiente con più sezioni può essere eseguito solo tramite una richiesta di risorse in coda. Come per la richiesta con sezione singola, utilizza il flag autocheckpoint-enabled nella chiamata per creare una risorsa in coda.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Consulta la Guida dell'utente di Multislice per i dettagli su tutte le opzioni disponibili. Una volta creata la richiesta di risorse in coda e nello stato ACTIVE, segui i passaggi successivi per eseguire Pax con Autocheckpoint.

  2. Installa Pax su tutte le VM nell'ambiente Multislice.

    Sulle VM TPU, installa jax[tpu] e l'ultima versione paxml su tutte le VM TPU nel tuo ambiente Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Avvia l'addestramento con la configurazione appropriata

    Questo esempio mostra come configurare il modello LmCloudSpmd2B per Autocheckpoint durante l'addestramento in un ambiente con più sezioni. Prima di eseguire lo script di addestramento, imposta DCN_MESH_SHAPE su [2, 1, 1] come mostrato nel seguente codice:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    All'avvio dell'addestramento, oltre ai flag della riga di comando trattati nel caso di una singola sezione, sono necessari altri tre flag:

    • num_hosts: il numero totale di host. In questo caso, è 2.
    • host_index: l'indice dell'host che avvia l'addestramento. Varia da 0 a N-1, dove N è il numero totale di host.
    • server_addr: l'indirizzo IP del worker 0 del nodo 0, con una porta inutilizzata (ad esempio, 8476). Per trovare queste informazioni, utilizza hostname -i sul worker 0 del nodo 0.

Checkpoint automatico con Orbax

La funzionalità Autocheckpoint non è limitata a MaxText o Pax. Qualsiasi framework in grado di acquisire l'indicatore SIGTERM e avviare un processo di checkpoint funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti JAX, offre queste funzionalità.

Come spiegato nella documentazione di Orbax, queste funzionalità sono abilitate per impostazione predefinita per gli utenti di orbax.checkpoint.CheckpointManager. Il metodo save chiamato dopo ogni passaggio controlla automaticamente se un evento di manutenzione è imminente e, in questo caso, salva un checkpoint anche se il numero del passaggio non è un multiplo di save_interval_steps. La documentazione GitHub illustra anche come eseguire l'uscita dall'addestramento dopo aver salvato un checkpoint automatico, con una modifica nel codice utente.