Checkpoint automático da Cloud TPU [pré-lançamento público]
Informações gerais
Historicamente, quando uma VM de TPU requer manutenção, o procedimento é iniciado imediatamente, sem dar tempo para os usuários realizarem ações que preservam o progresso, como salvar um checkpoint. Isso é mostrado na Figura 1(a).
Figura 1. Ilustração do recurso de checkpoint automático: (a) Sem o checkpoint, o progresso do treinamento do último checkpoint é perdido quando há um evento de manutenção futuro. (b) Com o checkpoint automático, o progresso do treinamento desde o último checkpoint pode ser preservado quando há um evento de manutenção futuro.
Use o checkpoint automático (Figura 1(b) para preservar o progresso do treinamento configurando o código para salvar um checkpoint não programado quando ocorrer um evento de manutenção. Quando ocorre um evento de manutenção, o progresso desde o último checkpoint é salvo automaticamente. O recurso funciona em frações únicas e em vários cortes.
O recurso de checkpoint automático funciona com frameworks que podem capturar o SIGTERM e, posteriormente, salvar um checkpoint. Os frameworks com suporte incluem MaxText, Pax e JAX com Orbax. O suporte para estruturas adicionais será anunciado à medida que forem disponibilizados.
Apenas TPUs (v2-v4 e v5e) criadas por meio da API Cloud TPU podem usar esse recurso por enquanto. O suporte para TPUs no GKE será anunciado quando estiver disponível.
Como usar o checkpoint automático
A funcionalidade do ponto de verificação automático fica desativada por padrão. Quando você cria uma TPU ou um recurso em fila, é possível ativá-lo adicionando a sinalização --autocheckpoint-enabled
ao provisionar a TPU.
Com o recurso ativado, o Cloud TPU executa as etapas a seguir depois de receber a notificação de um evento de manutenção:
- Capture o SIGTERM enviado ao processo usando o dispositivo TPU,
- Aguarda até que o processo seja encerrado ou 5 minutos tenham se passado, o que ocorrer primeiro, e executa a manutenção nas frações afetadas.
A infraestrutura usada pelo checkpoint automático não depende do framework de ML. Qualquer framework de ML é compatível com o checkpoint automático, desde que seja possível capturar o sinal SIGTERM e iniciar um processo de checkpoint.
No código do aplicativo, é necessário ativar os recursos do ponto de verificação automático fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar sinalizações de linha de comando ao iniciar o treinamento. Consulte o Guia de início rápido do checkpoint automático com o Pax. Nos bastidores, os frameworks salvam um checkpoint não programado quando um SIGTERM é recebido, e a VM da TPU afetada passa por manutenção quando a TPU não está mais em uso.
Guia de início rápido: checkpoint automático com MaxText
MaxText é um "LLM de alto desempenho, arbitrariamente escalonável, de código aberto e bem testado, escrito em Cloud TPUs puras do Python/JAX". MaxText contém toda a configuração necessária para usar o recurso Autocheckpoint.
O README do MaxText descreve duas maneiras de executar o MaxText em escala:
- Usando
multihost_runner.py
, recomendado para experimentos - Usando
multihost_job.job
, recomendado para produção
Ao usar multihost_runner.py
, a única alteração necessária
é definir a sinalização autocheckpoint-enabled
ao provisionar
o recurso na fila. Ao usar
multihost_job.py
, a única mudança necessária é especificar a
sinalização de linha de comando ENABLE_AUTOCHECKPOINT=true
ao iniciar o job.
Guia de início rápido: checkpoint automático com o Pax em fatias únicas
Nesta seção, fornecemos um exemplo de como configurar e usar o checkpoint automático com o Pax em uma única fração. Com a configuração adequada:
- Um checkpoint será salvo quando ocorrer um evento de manutenção.
- O Cloud TPU realizará a manutenção nas VMs da TPU afetadas depois que o checkpoint for salvo.
- Quando o Cloud TPU concluir a manutenção, use a VM da TPU normalmente.
Use a sinalização
autocheckpoint-enabled
ao criar a VM da TPU ou o recurso na fila.Exemplo:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Instalar o Pax em uma única fração
O recurso Autocheckpoint funciona nas versões do Pax 1.1.0 ou posteriores. Nas VMs da TPU, instale
jax[tpu]
e ospaxml
mais recentes:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Iniciar o treinamento com a configuração adequada
O exemplo a seguir mostra como configurar o modelo
LmCloudSpmd2B
para salvar checkpoints acionados pelo checkpoint em um bucket do Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Observe as duas sinalizações que são transmitidas ao comando:
jax_fully_async_checkpoint
: com essa flag ativada,orbax.checkpoint.AsyncCheckpointer
será usado. A classeAsyncCheckpointer
salva automaticamente um checkpoint quando o script de treinamento recebe um sinal SIGTERM.exit_after_ondemand_checkpoint
: com essa sinalização ativada, o processo da TPU é encerrado depois que o checkpoint automático é salvo, o que aciona a manutenção a ser realizada imediatamente. Se você não usar essa sinalização, o treinamento continuará depois que o checkpoint for salvo, e o Cloud TPU aguardará um tempo limite (cinco minutos) antes de executar a manutenção necessária.
Guia de início rápido: checkpoint automático com Pax no Multislice
O checkpoint automático não funciona apenas para frações únicas, mas também para Multislice. Nesta seção, detalhamos as etapas necessárias para usar o checkpoint automático com o recurso Multislice.
Especificar o checkpoint automático durante a criação de recursos na fila.
Um ambiente Multislice só pode ser provisionado por meio de uma solicitação de recurso em fila. De forma semelhante ao caso de fatia única, use a sinalização
autocheckpoint-enabled
na chamada para criar um recurso na fila.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud alpha compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Consulte o Guia do usuário de multislice para ver detalhes sobre todas as opções disponíveis. Depois que a solicitação de recurso na fila for criada e no estado
ACTIVE
, siga as próximas etapas para executar o Pax com o Checkpoint automático.Instalar o Pax em todas as VMs no ambiente Multislice.
Nas VMs de TPU, instale
jax[tpu]
e opaxml
mais recente em todas as VMs de TPU no ambiente Multislice:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Iniciar o treinamento com a configuração adequada
Neste exemplo, mostramos como configurar o modelo
LmCloudSpmd2B
para o checkpoint automático ao realizar o treinamento em um ambiente multislice. Antes de executar o script de treinamento, defina DCN_MESH_SHAPE como [2, 1, 1], conforme mostrado no código a seguir:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Ao iniciar o treinamento, além das sinalizações de linha de comando discutidas no caso de fatia única, são necessárias mais três:
num_hosts
: o número total de hosts. Neste caso, é 2.host_index
: o índice do host que está iniciando o treinamento. Ela varia de 0 aN-1
, em queN
é o número total de hosts.server_addr
: o endereço IP do worker 0 do nó 0, com uma porta não utilizada (por exemplo, 8476). Para encontrar essas informações, usehostname -i
no worker 0 do nó 0.
Checkpoint automático com o Orbax
O recurso "Autocheckpoint" não está limitado aos modelos MaxText ou Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de checkpoint funciona com a infraestrutura fornecida pelo Autocheckpoint. O Orbax, um namespace que fornece bibliotecas de utilitários comuns para usuários do JAX, oferece esses recursos.
Conforme explicado na documentação da Orbax,
esses recursos são ativados por padrão para usuários
de orbax.checkpoint.CheckpointManager
. O método save
chamado após cada etapa verifica automaticamente se um evento de
manutenção está iminente e, em caso afirmativo, salva um checkpoint, mesmo que o número da etapa
não seja um múltiplo de save_interval_steps
.
A documentação do GitHub
também mostra como sair do treinamento depois de salvar um
checkpoint automático, com uma modificação no código do usuário.