Checkpoint automático da Cloud TPU [pré-lançamento público]
Visão geral
Historicamente, quando uma VM TPU requer manutenção, o procedimento é iniciado imediatamente, sem deixar tempo para que os usuários realizem ações de preservação do progresso, como salvar um ponto de verificação. Isso é mostrado na Figura 1(a).
Figura 1. Ilustração do recurso de checkpoint automático: (a) Sem o checkpoint automático, o progresso do treinamento do último checkpoint é perdido quando há um evento de manutenção. (b) Com o ponto de verificação automático, o progresso do treinamento desde o último ponto de verificação pode ser preservado quando há um evento de manutenção.
Você pode usar o Autocheckpoint (Figura 1(b)) para preservar o progresso do treinamento, configurando o código para salvar um ponto de controle não programado quando um evento de manutenção ocorrer. Quando um evento de manutenção ocorre, o progresso desde o último ponto de verificação é salvo automaticamente. O recurso funciona em fatias únicas e em Multislice.
O recurso de checkpoint automático funciona com frameworks que podem capturar SIGTERM e, em seguida, salvar um checkpoint. Os frameworks com suporte incluem MaxText, Pax e JAX com Orbax. O suporte a outros frameworks será anunciado conforme forem disponibilizados.
Por enquanto, apenas TPUs (v2-v4 e v5e) criadas pela API Cloud TPU podem usar esse recurso. O suporte a TPUs no GKE será anunciado quando estiver disponível.
Como usar o checkpoint automático
A funcionalidade de verificação automática fica desativada por padrão. Ao criar um TPU ou um recurso enfileirado, é possível ativar essa opção adicionando a flag --autocheckpoint-enabled
ao provisionar o TPU.
Com o recurso ativado, o Cloud TPU
executa as etapas a seguir quando recebe a notificação de um
evento de manutenção:
- Capture o SIGTERM enviado para o processo usando o dispositivo TPU.
- Aguarda até que o processo seja encerrado ou 5 minutos se passem, o que acontecer primeiro, e realiza a manutenção nas fatias afetadas.
A infraestrutura usada pelo Autocheckpoint é independente do framework de ML. Qualquer framework de ML pode oferecer suporte ao Autocheckpoint, desde que capture o sinal SIGTERM e inicie um processo de checkpoint.
No código do aplicativo, é necessário ativar os recursos de verificação automática fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar flags de linha de comando ao iniciar o treinamento. Consulte o Guia de início rápido do autocheckpoint com o Pax. Nos bastidores, os frameworks salvam um ponto de controle não programado quando um SIGTERM é recebido e a VM da TPU afetada passa por manutenção quando a TPU não está mais em uso.
Guia de início rápido: checkpoint automático com MaxText
MaxText é um LLM de "alto desempenho, escalonável de forma arbitrária, de código aberto e bem testado, escrito em Python/JAX puro para Cloud TPUs". O MaxText contém toda a configuração necessária para usar o recurso de verificação automática.
O README do MaxText descreve duas maneiras de executar o MaxText em escala:
- Como usar
multihost_runner.py
, recomendado para experimentos - Como usar
multihost_job.job
, recomendado para produção
Ao usar multihost_runner.py
, a única mudança necessária
é definir a flag autocheckpoint-enabled
ao provisionar
o recurso na fila. Ao usar
multihost_job.py
, a única mudança necessária é especificar a
flag de linha de comando ENABLE_AUTOCHECKPOINT=true
ao iniciar o job.
Guia de início rápido: verificação automática com Pax em fatias únicas
Nesta seção, mostramos um exemplo de como configurar e usar o Autocheckpoint com Pax em uma única fatia. Com a configuração adequada:
- Um ponto de controle será salvo quando um evento de manutenção ocorrer.
- A Cloud TPU vai realizar a manutenção nas VMs afetadas depois que o ponto de verificação for salvo.
- Quando a manutenção da Cloud TPU for concluída, você poderá usar a VM TPU normalmente.
Use a flag
autocheckpoint-enabled
ao criar a VM de TPU ou o recurso em fila.Exemplo:
PROJECT=your-gcp-project-name ZONE=zone-you-want-to-use NODE_ID=your-node-id ACCELERATOR_TYPE=your-accelerator-type gcloud config set project $PROJECT gcloud config set compute/zone $ZONE
gcloud alpha compute tpus tpu-vm create $NODE_ID \ --accelerator-type $ACCELERATOR_TYPE \ --version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Instalar o Pax em uma única fatia
O recurso de verificação automática funciona nas versões do Pax a partir da 1.1.0. Nas VMs do TPU, instale
jax[tpu]
e opaxml
mais recente:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Inicie o treinamento com a configuração adequada
O exemplo a seguir mostra como configurar o modelo
LmCloudSpmd2B
para salvar pontos de verificação acionados pelo Autocheckpoint em um bucket do Google Cloud Storage:JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Observe as duas flags transmitidas para o comando:
jax_fully_async_checkpoint
: com essa flag ativada,orbax.checkpoint.AsyncCheckpointer
será usado. A classeAsyncCheckpointer
salva automaticamente um ponto de controle quando o script de treinamento recebe um sinal SIGTERM.exit_after_ondemand_checkpoint
: com essa flag ativada, os processos da TPU são encerrados depois que o checkpoint automático é salvo, o que aciona a manutenção para ser realizada imediatamente. Se você não usar essa flag, o treinamento vai continuar depois que o checkpoint for salvo e o Cloud TPU vai esperar que um tempo limite ocorra (5 minutos) antes de realizar a manutenção necessária.
Guia de início rápido: checkpoint automático com Pax no Multislice
O ponto de verificação automático funciona não apenas para fatias únicas, mas também para várias fatias. Esta seção descreve as etapas necessárias para usar o Autocheckpoint com o Multislice.
Especifique o ponto de verificação automático durante a criação de recursos na fila.
Um ambiente com várias fatias só pode ser provisionado por uma solicitação de recurso em fila. Semelhante ao caso de fatia única, use a flag
autocheckpoint-enabled
na chamada para criar um recurso em fila.QR_ID=your-qr-id NODE_COUNT=your-node-count ACCELERATOR_TYPE=your-accelerator-type gcloud compute tpus queued-resources create $QR_ID \ --node-count $NODE_COUNT \ --accelerator-type $ACCELERATOR_TYPE \ --runtime-version tpu-ubuntu2204-base \ --autocheckpoint-enabled
Consulte o Guia do usuário do Multislice para detalhes sobre todas as opções disponíveis. Depois que a solicitação de recurso em fila for criada e estiver no estado
ACTIVE
, siga as próximas etapas para executar o Pax com o Autocheckpoint.Instale o Pax em todas as VMs no ambiente Multislice.
Nas VMs de TPU, instale o
jax[tpu]
e opaxml
mais recente em todas as VMs de TPU no ambiente de fatias múltiplas:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Inicie o treinamento com a configuração adequada
Este exemplo mostra como configurar o modelo
LmCloudSpmd2B
para o Autocheckpoint ao treinar em um ambiente de várias fatias. Antes de executar o script de treinamento, defina DCN_MESH_SHAPE como [2, 1, 1], conforme mostrado no código abaixo:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 4, 1] DCN_MESH_SHAPE = [2, 1, 1]
Ao iniciar o treinamento, além das flags de linha de comando discutidas no caso de fatia única, três outras são necessárias:
num_hosts
: o número total de hosts. Nesse caso, é 2.host_index
: o índice do host que inicia o treinamento. Ela varia de 0 aN-1
, sendoN
o número total de hosts.server_addr
: o endereço IP do worker 0 do nó 0, com uma porta não utilizada (por exemplo, 8476). Para encontrar essas informações, usehostname -i
no worker 0 do nó 0.
Ponto de verificação automático com Orbax
O recurso de verificação automática não é limitado a MaxText ou Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de verificação funciona com a infraestrutura fornecida pelo Autocheckpoint. O Orbax, um namespace que oferece bibliotecas de utilitários comuns para usuários do JAX, oferece esses recursos.
Conforme explicado na documentação do Orbex,
esses recursos são ativados por padrão para os usuários
do orbax.checkpoint.CheckpointManager
. O método save
que é chamado após cada etapa verifica automaticamente se um evento de manutenção
está iminente e, se estiver, salva um ponto de controle mesmo que o número da etapa
não seja um múltiplo de save_interval_steps
.
A documentação do GitHub (link em inglês)
também ilustra como fazer com que o treinamento saia após salvar um
Autocheckpoint, com uma modificação no código do usuário.