Checkpoint automatico di Cloud TPU [anteprima pubblica]
Panoramica
Storicamente, quando una VM TPU richiede la manutenzione, la procedura viene avviata immediatamente, senza lasciare tempo agli utenti per eseguire azioni che garantiscono l'avanzamento, come il salvataggio di un checkpoint. Ciò è mostrato nella Figura 1(a).
Fig. 1. Illustrazione della funzionalità Autocheckpoint: (a) Senza Autocheckpoint, l'avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando si verifica un evento di manutenzione imminente. (b) Con il checkpoint automatico, l'avanzamento dell'addestramento dall'ultimo checkpoint può essere conservato in caso di un evento di manutenzione imminente.
Puoi utilizzare il punto di controllo automatico (Figura 1(b)) per preservare l'avanzamento dell'addestramento configurando il codice per salvare un checkpoint non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, l'avanzamento dall'ultimo checkpoint viene salvato automaticamente. Questa funzionalità può essere usata sia sia su una sezione singola che su più sezioni.
La funzionalità Autocheckpoint funziona con framework in grado di acquisire SIGTERM e successivamente salvare un checkpoint. I framework supportati includono MaxText, Pax e JAX con Orbax. Il supporto per framework aggiuntivi verrà annunciato non appena saranno disponibili.
Al momento solo le TPU (v2-v4 e v5e) create tramite l'API Cloud TPU possono utilizzare questa funzionalità. Il supporto per le TPU in GKE verrà annunciato quando sarà disponibile.
Utilizzo di Autocheckpoint
La funzionalità checkpoint automatico è disattivata per impostazione predefinita. Quando crei una TPU o una risorsa in coda, puoi abilitarla aggiungendo il flag --autocheckpoint-enabled
durante il provisioning della TPU.
Con la funzionalità abilitata, Cloud TPU esegue i seguenti passaggi quando riceve la notifica di un evento di manutenzione:
- Acquisisci SIGTERM inviato al processo utilizzando il dispositivo TPU,
- Attende fino all'uscita del processo o dopo che sono trascorsi 5 minuti, a seconda dell'evento che si verifica per primo, ed esegue la manutenzione delle sezioni interessate.
Tieni presente che l'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework ML. Qualsiasi framework ML può supportare Autocheckpoint a condizione di poter acquisire il segnale SIGTERM e avviare un processo di checkpointing.
Nel codice dell'applicazione, devi abilitare le funzionalità Autocheckpoint fornite dal framework ML. In Pax, ad esempio, ciò significa abilitare i flag della riga di comando all'avvio dell'addestramento (consulta la guida rapida di Autocheckpoint con Pax). Dietro le quinte, i framework salvano un checkpoint non pianificato alla ricezione di un SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.
Guida rapida: checkpoint automatico con MaxText
MaxText è un "LLM open source, ad alte prestazioni, scalabile e ben collaudato, scritto in Python/JAX puro per il targeting di Cloud TPU". MaxText contiene tutte le configurazioni necessarie per utilizzare la funzionalità Autocheckpoint.
Il file README di MaxText descrive due modi per eseguire MaxText su larga scala:
- Utilizzo di
multihost_runner.py
, consigliato per la sperimentazione - Utilizza
multihost_job.job
, consigliato per la produzione
Quando utilizzi multihost_runner.py
, l'unica modifica richiesta è impostare il flag autocheckpoint-enabled
durante il provisioning della risorsa in coda. Quando utilizzi multihost_job.py
, l'unica modifica necessaria è specificare il flag della riga di comando ENABLE_AUTOCHECKPOINT=true
all'avvio del job.
Guida rapida: checkpoint automatico con Pax su sezioni singole
In questa sezione, forniamo un esempio di come configurare e utilizzare Autocheckpoint con Pax su una singola sezione. Con la configurazione appropriata:
- Quando si verifica un evento di manutenzione, verrà salvato un checkpoint.
- Cloud TPU eseguirà la manutenzione delle VM TPU interessate dopo il salvataggio del checkpoint.
- Quando Cloud TPU completa la manutenzione, puoi utilizzare la VM TPU come di consueto.
Utilizza il flag
autocheckpoint-enabled
durante la creazione della VM TPU o della risorsa in coda.Ad esempio:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Installa Pax su una singola sezione
La funzionalità Autocheckpoint funziona su versioni di Pax successive alla 1.1.0. Sulle VM TPU, installa
jax[tpu]
e la versione più recente dipaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Avvia la formazione con la configurazione appropriata
L'esempio seguente mostra come configurare il modello
LmCloudSpmd2B
per salvare i checkpoint attivati da Autocheckpoint in un bucket Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Nota i due flag che vengono passati al comando:
jax_fully_async_checkpoint
: con questo flag attivo, verrà utilizzatoorbax.checkpoint.AsyncCheckpointer
. La classeAsyncCheckpointer
salva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.exit_after_ondemand_checkpoint
: se questo flag è attivo, la TPU elabora le uscite dopo il salvataggio corretto del checkpoint automatico, che attiva l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà il verificarsi di un timeout (5 minuti) prima di eseguire la manutenzione richiesta.
Guida rapida: checkpoint automatico con Pax su Multislice
Il punto di controllo automatico funziona non solo per le sezioni singole, ma anche per Multisezione. Questa sezione descrive i passaggi necessari per utilizzare Autocheckpoint con Multislice.
Specifica il punto di controllo automatico durante la creazione delle risorse in coda.
È possibile eseguire il provisioning di un ambiente multislice solo tramite una richiesta di risorse in coda. Come per il caso a sezione singola, utilizza il flag
autocheckpoint-enabled
nella chiamata per creare una risorsa in coda.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud alpha compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Per i dettagli su tutte le opzioni disponibili, consulta la guida dell'utente di Multislice. Dopo aver creato la richiesta di risorse in coda e nello stato
ACTIVE
, segui i passaggi successivi per eseguire Pax con Autocheckpoint.Installare Pax su tutte le VM nell'ambiente Multislice.
Sulle VM TPU, installa
jax[tpu]
e la versione più recente dipaxml
su tutte le VM TPU nel tuo ambiente multislice:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Avvia la formazione con la configurazione appropriata
Questo esempio mostra come configurare il modello
LmCloudSpmd2B
per il checkpoint automatico durante l'addestramento in un ambiente multisezione. Prima di eseguire lo script di addestramento, imposta DCN_MESH_SHAPE su [2, 1, 1] come mostrato nel codice seguente:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Quando avvii l'addestramento, oltre ai flag della riga di comando discussi nel caso single-slice, sono richiesti altri tre:
num_hosts
: il numero totale di host. In questo caso, è 2.host_index
: l'indice dell'host che avvia l'addestramento. Varia da 0 aN-1
, doveN
è il numero totale di host.server_addr
: l'indirizzo IP del worker 0 del nodo 0, con una porta inutilizzata (ad esempio, 8476). Per trovare queste informazioni, utilizzahostname -i
sul worker 0 del nodo 0.
Checkpoint automatico con Orbax
La funzionalità Checkpoint automatico non è limitata a MaxText o Pax. Qualsiasi framework in grado di acquisire il segnale SIGTERM e avviare un processo di checkpoint funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti JAX, offre queste funzionalità.
Come spiegato nella documentazione di Orbax, queste funzionalità sono attive per impostazione predefinita per gli utenti di orbax.checkpoint.CheckpointManager
. Il metodo save
richiamato dopo ogni passaggio controlla automaticamente se un evento di manutenzione è imminente e, in questo caso, salva un checkpoint anche se il numero di passaggio non è un multiplo di save_interval_steps
.
La documentazione di GitHub illustra anche come uscire dall'addestramento dopo aver salvato un checkpoint automatico, con una modifica al codice utente.