Checkpoint automatico Cloud TPU [Anteprima pubblica]
Panoramica
In passato, quando una VM TPU richiedeva manutenzione, la procedura veniva avviata immediatamente, senza lasciare tempo agli utenti per eseguire azioni che preservano i progressi, come il salvataggio di un checkpoint. Questo è come mostrato nella Figura 1(a).
Fig. 1. Illustrazione della funzionalità Autocheckpoint: (a) senza Autocheckpoint, l'avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando è imminente un evento di manutenzione. (b) Con il controllo automatico, i progressi dell'addestramento dall'ultimo controllo possono essere preservati in caso di un evento di manutenzione imminente.
Puoi utilizzare il controllo automatico (Figura 1(b)) per preservare i progressi dell'addestramento configurando il codice in modo da salvare un controllo non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, i progressi dall'ultimo checkpoint vengono salvati automaticamente. La funzionalità è compatibile con entrambe le sezioni singole e Multislice.
La funzionalità Autocheckpoint funziona con i framework che possono acquisire SIGTERM e successivamente salvare un checkpoint. I framework supportati includono MaxText, Pax, e JAX con Orbax. Il supporto di altri framework verrà annunciato non appena saranno disponibili.
Al momento solo le TPU (v2-v4 e v5e) create tramite l'API Cloud TPU possono utilizzare questa funzionalità. Il supporto delle TPU in GKE verrà annunciato non appena sarà disponibile.
Utilizzo di Controllo automatico
La funzionalità del checkpoint automatico è disattivata per impostazione predefinita. Quando crei un
una TPU o una risorsa in coda,
puoi abilitarlo aggiungendo il flag --autocheckpoint-enabled
durante il provisioning
la TPU.
Con la funzionalità abilitata, Cloud TPU
esegue la seguente procedura quando riceve la notifica di un
evento di manutenzione:
- Acquisisci SIGTERM inviato al processo utilizzando il dispositivo TPU.
- Attende la chiusura del processo o la scadenza di 5 minuti, a seconda dell'evento la priorità ed esegue la manutenzione delle sezioni interessate.
Tieni presente che l'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework ML. Qualsiasi framework ML può supportare il controllo automatico, a condizione che possa acquisire l'indicatore SIGTERM e avviare un processo di controllo.
Nel codice dell'applicazione, devi attivare le funzionalità di controllo automatico fornite dal framework ML. In Pax, ad esempio, questo significa attivare i flag della riga di comando al momento dell'avvio della formazione (consulta la guida rapida all'autocheckpoint con Pax). Dietro le quinte, i framework salvano un checkpoint non pianificato quando viene ricevuto un SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.
Guida rapida: controllo automatico con MaxText
MaxText è un "LLM open source, ben testato, ad alte prestazioni e arbitrariamente scalabile scritto in puro Python/JAX e rivolto alle Cloud TPU". MaxText contiene tutta la configurazione necessaria per utilizzare la funzionalità Punto di controllo automatico.
Il file README di MaxText descrive due modi per eseguire MaxText su larga scala:
- Utilizzo di
multihost_runner.py
, consigliato per la sperimentazione - Utilizzo di
multihost_job.job
, consigliato per la produzione
Quando utilizzi multihost_runner.py
, l'unica modifica richiesta è impostare il flag autocheckpoint-enabled
durante il provisioning della risorsa in coda. Quando si utilizza
multihost_job.py
, l'unica modifica necessaria è specificare il
ENABLE_AUTOCHECKPOINT=true
all'avvio del job.
Guida rapida: controllo automatico con Pax su singoli slice
In questa sezione viene fornito un esempio di come configurare e utilizzare il controllo automatico con Pax su un singolo slice. Con la configurazione appropriata:
- Verrà salvato un checkpoint quando si verifica un evento di manutenzione.
- Cloud TPU eseguirà la manutenzione delle VM TPU interessate dopo la salvataggio del checkpoint.
- Al termine della manutenzione di Cloud TPU, puoi utilizzare la VM TPU come di consueto.
Utilizza il flag
autocheckpoint-enabled
quando crei la VM TPU o la risorsa in coda.Ad esempio:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Installare Pax su un singolo slice
La funzione Autocheckpoint funziona su versioni Pax >= 1.1.0. Sulle VM TPU, installa
jax[tpu]
e la versione più recente dipaxml
:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Avvia la formazione con la configurazione appropriata
L'esempio seguente mostra come configurare il modello
LmCloudSpmd2B
per salvare i checkpoint attivati da Autocheckpoint in un bucket Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Tieni presente i due flag passati al comando:
jax_fully_async_checkpoint
: Quando questo flag sarà attivato, verrà utilizzatoorbax.checkpoint.AsyncCheckpointer
. La classeAsyncCheckpointer
salva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.exit_after_ondemand_checkpoint
: Quando questo flag è attivato, i processi TPU si escono dopo Il checkpoint automatico è stato salvato correttamente e questo attiva la manutenzione da eseguire immediatamente. Se non utilizzi questo l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà il verificarsi di un timeout (5 minuti) prima di eseguire la manutenzione richiesta.
Guida rapida: controllo automatico con Pax su Multislice
Il controllo automatico funziona non solo per i singoli slice, ma anche per Multislice. Questa sezione descrive in dettaglio i passaggi necessari per utilizzare Autocheckpoint con Multislice.
Specifica il controllo automatico durante la creazione della risorsa in coda.
È possibile eseguire il provisioning di un ambiente Multislice solo tramite una richiesta di risorse in coda. Come per il caso di una singola fetta, utilizza il flag
autocheckpoint-enabled
nella chiamata per creare una risorsa in coda.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Per informazioni dettagliate su tutte le opzioni disponibili, consulta la Guida dell'utente di Multislice. Dopo che la risorsa in coda viene creata la richiesta e nello stato
ACTIVE
, segui i passaggi successivi per eseguire Pax Checkpoint automatico.Installa Pax su tutte le VM nell'ambiente Multislice.
Sulle VM TPU, installa
jax[tpu]
e la versione più recente dipaxml
su tutte le VM TPU nel tuo ambiente Multislice:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Avvia la formazione con la configurazione appropriata
Questo esempio mostra come configurare il modello
LmCloudSpmd2B
per il controllo automatico durante l'addestramento in un ambiente Multislice. Prima del giorno eseguendo lo script di addestramento, imposta DCN_MESH_SHAPE su [2, 1, 1] come mostrato il seguente codice:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
All'avvio dell'addestramento, oltre ai flag della riga di comando descritti nel caso di una singola sezione, ne occorrono altre tre:
num_hosts
: il numero totale di host. In questo caso, è 2.host_index
: l'indice dell'host che avvia l'addestramento. Varia da 0 aN-1
, doveN
è il numero totale di host.server_addr
: l'indirizzo IP del worker 0 del nodo 0, con un (ad esempio, 8476). Per trovare queste informazioni, utilizzahostname -i
sul worker 0 del nodo 0.
Checkpoint automatico con Orbax
La funzionalità di controllo automatico non è limitata a MaxText o Pax. Qualsiasi framework in grado di acquisire il segnale SIGTERM e avviare una il processo di checkpointing funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti di JAX, offre queste funzionalità.
Come spiegato nella documentazione Orbax,
queste funzionalità sono attive per impostazione predefinita
di orbax.checkpoint.CheckpointManager
. Il metodo save
che viene chiamato dopo ogni passaggio verifica automaticamente se
è imminente e, in questo caso, salva un checkpoint anche se il numero del passaggio
non è un multiplo di save_interval_steps
.
La documentazione di GitHub illustra anche come far uscire l'addestramento dopo aver salvato un checkpoint automatico, con una modifica nel codice utente.