Controllo automatico Cloud TPU [anteprima pubblica]

Panoramica

In passato, quando una VM TPU richiedeva manutenzione, la procedura veniva avviata immediatamente, senza lasciare tempo agli utenti di eseguire azioni che preservano i progressi, come il salvataggio di un checkpoint. Come mostrato nella Figura 1(a).

autocheckpoint

Figura 1. Illustrazione della funzionalità Autocheckpoint: (a) senza Autocheckpoint, l'avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando è imminente un evento di manutenzione. (b) Con il controllo automatico, i progressi dell'addestramento dall'ultimo controllo possono essere preservati in caso di un evento di manutenzione imminente.

Puoi utilizzare il controllo automatico (Figura 1(b)) per preservare i progressi dell'addestramento configurando il codice in modo da salvare un controllo non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, i progressi dall'ultimo checkpoint vengono salvati automaticamente. La funzionalità è supportata sia per le singole sezioni sia per le sezioni multiple.

La funzionalità Autocheckpoint funziona con i framework che possono acquisire SIGTERM e successivamente salvare un checkpoint. I framework supportati includono MaxText, Pax, e JAX con Orbax. Il supporto di altri framework verrà annunciato non appena saranno disponibili.

Al momento solo le TPU (v2-v4 e v5e) create tramite l'API Cloud TPU possono utilizzare questa funzionalità. Il supporto delle TPU in GKE verrà annunciato quando sarà disponibile.

Utilizzo di Controllo automatico

La funzionalità di controllo automatico è disattivata per impostazione predefinita. Quando crei un TPU o una risorsa in coda, puoi attivarla aggiungendo il flag --autocheckpoint-enabled durante il provisioning del TPU. Con la funzionalità attivata, Cloud TPU esegue i seguenti passaggi dopo aver ricevuto la notifica di un evento di manutenzione:

  1. Acquisisci SIGTERM inviato al processo utilizzando il dispositivo TPU.
  2. Attende l'uscita del processo o il trascorrere di 5 minuti, a seconda dell'evento che si verifica per primo, ed esegue la manutenzione delle sezioni interessate.

Tieni presente che l'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework ML. Qualsiasi framework ML può supportare il controllo automatico se è in grado di acquisire l'indicatore SIGTERM e avviare un processo di controllo.

Nel codice dell'applicazione, devi attivare le funzionalità di controllo automatico fornite dal framework ML. In Pax, ad esempio, questo significa attivare i flag della riga di comando al momento dell'avvio della formazione (consulta la guida rapida all'autocheckpoint con Pax). Dietro le quinte, i framework salvano un checkpoint non pianificato quando viene ricevuto un SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.

Guida rapida: controllo automatico con MaxText

MaxText è un "LLM open source, ben testato, ad alte prestazioni e arbitrariamente scalabile scritto in puro Python/JAX e rivolto alle Cloud TPU". MaxText contiene tutta la configurazione necessaria per utilizzare la funzionalità di controllo automatico.

Il file README di MaxText descrive due modi per eseguire MaxText su larga scala:

Quando utilizzi multihost_runner.py, l'unica modifica richiesta è impostare il flag autocheckpoint-enabled durante il provisioning della risorsa in coda. Quando utilizzi multihost_job.py, l'unica modifica richiesta è specificare il ENABLE_AUTOCHECKPOINT=true flag della riga di comando al momento dell'avvio del job.

Guida rapida: controllo automatico con Pax su singoli slice

In questa sezione viene fornito un esempio di come configurare e utilizzare il controllo automatico con Pax su un singolo slice. Con la configurazione appropriata:

  • Un checkpoint verrà salvato quando si verifica un evento di manutenzione.
  • Cloud TPU eseguirà la manutenzione delle VM TPU interessate dopo la salvataggio del checkpoint.
  • Al termine della manutenzione di Cloud TPU, puoi utilizzare la VM TPU come di consueto.
  1. Utilizza il flag autocheckpoint-enabled durante la creazione della VM TPU o della risorsa in coda.

    Ad esempio:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
  2. Installare Pax su un singolo slice

    La funzionalità di controllo automatico funziona sulle versioni di Pax >= 1.1.0. Sulle VM TPU, installa jax[tpu] e la versione più recente di paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Avvia la formazione con la configurazione appropriata

    L'esempio seguente mostra come configurare il modello LmCloudSpmd2B per salvare i checkpoint attivati da Autocheckpoint in un bucket Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Tieni presente i due flag passati al comando:

    • jax_fully_async_checkpoint: se questo flag è attivo, verrà utilizzato orbax.checkpoint.AsyncCheckpointer. La classe AsyncCheckpointer salva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.
    • exit_after_ondemand_checkpoint: se questo flag è attivo, i processi TPU escono dopo il salvataggio corretto del checkpoint automatico, attivando l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà che si verifichi un timeout (5 minuti) prima di eseguire la manutenzione richiesta.

Guida rapida: controllo automatico con Pax su Multislice

Il controllo automatico funziona non solo per i singoli slice, ma anche per Multislice. Questa sezione descrive nel dettaglio i passaggi necessari per utilizzare il controllo automatico con Multislice.

  1. Specifica il controllo automatico durante la creazione della risorsa in coda.

    È possibile eseguire il provisioning di un ambiente Multislice solo tramite una richiesta di risorse in coda. Come per la richiesta con una singola fetta, utilizza il flag autocheckpoint-enabled nella chiamata per creare una risorsa in coda.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled

    Per informazioni dettagliate su tutte le opzioni disponibili, consulta la Guida dell'utente di Multislice. Dopo aver creato la richiesta di risorse in coda e averla impostata sullo stato ACTIVE, segui i passaggi successivi per eseguire Pax con il controllo automatico.

  2. Installa Pax su tutte le VM nell'ambiente Multislice.

    Sulle VM TPU, installa jax[tpu] e la versione più recente di paxml su tutte le VM TPU nel tuo ambiente Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Avvia la formazione con la configurazione appropriata

    Questo esempio mostra come configurare il modello LmCloudSpmd2B per il controllo automatico durante l'addestramento in un ambiente Multislice. Prima di eseguire lo script di addestramento, imposta DCN_MESH_SHAPE su [2, 1, 1] come mostrato nel codice seguente:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]

    Quando avvii l'addestramento, oltre ai flag della riga di comando discussi nel caso di una singola sezione, sono necessari altri tre flag:

    • num_hosts: il numero totale di host. In questo caso è 2.
    • host_index: l'indice dell'organizzatore che avvia la formazione. Varia da 0 a N-1, dove N è il numero totale di host.
    • server_addr: l'indirizzo IP del worker 0 del nodo 0, con una porta non utilizzata (ad esempio 8476). Per trovare queste informazioni, utilizza hostname -i sul worker 0 del nodo 0.

Controllo automatico con Orbax

La funzionalità di controllo automatico non è limitata a MaxText o Pax. Qualsiasi framework che può acquisire l'indicatore SIGTERM e avviare un procedura di checkpointing funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti di JAX, offre queste funzionalità.

Come spiegato nella documentazione di Orbax, queste funzionalità sono abilitate per impostazione predefinita per gli utenti di orbax.checkpoint.CheckpointManager. Il metodo save chiamato dopo ogni passaggio controlla automaticamente se è imminente un evento di manutenzione e, in questo caso, salva un checkpoint anche se il numero del passaggio non è un multiplo di save_interval_steps. La documentazione di GitHub illustra anche come far uscire l'addestramento dopo aver salvato un checkpoint automatico, con una modifica nel codice utente.