Checkpoint automático da Cloud TPU [pré-lançamento público]

Informações gerais

Historicamente, quando uma VM de TPU requer manutenção, o procedimento é iniciado imediatamente, sem dar tempo para os usuários realizarem ações que preservam o progresso, como salvar um checkpoint. Isso é mostrado na Figura 1(a).

checkpoint automático

Figura 1. Ilustração do recurso de checkpoint automático: (a) Sem o checkpoint, o progresso do treinamento do último checkpoint é perdido quando há um evento de manutenção futuro. (b) Com o checkpoint automático, o progresso do treinamento desde o último checkpoint pode ser preservado quando há um evento de manutenção futuro.

Use o checkpoint automático (Figura 1(b) para preservar o progresso do treinamento configurando o código para salvar um checkpoint não programado quando ocorrer um evento de manutenção. Quando ocorre um evento de manutenção, o progresso desde o último checkpoint é salvo automaticamente. O recurso funciona em frações únicas e em vários cortes.

O recurso de checkpoint automático funciona com frameworks que podem capturar o SIGTERM e, posteriormente, salvar um checkpoint. Os frameworks com suporte incluem MaxText, Pax e JAX com Orbax. O suporte para estruturas adicionais será anunciado à medida que forem disponibilizados.

Apenas TPUs (v2-v4 e v5e) criadas por meio da API Cloud TPU podem usar esse recurso por enquanto. O suporte para TPUs no GKE será anunciado quando estiver disponível.

Como usar o checkpoint automático

A funcionalidade do ponto de verificação automático fica desativada por padrão. Quando você cria uma TPU ou um recurso em fila, é possível ativá-lo adicionando a sinalização --autocheckpoint-enabled ao provisionar a TPU. Com o recurso ativado, o Cloud TPU executa as etapas a seguir depois de receber a notificação de um evento de manutenção:

  1. Capture o SIGTERM enviado ao processo usando o dispositivo TPU,
  2. Aguarda até que o processo seja encerrado ou 5 minutos tenham se passado, o que ocorrer primeiro, e executa a manutenção nas frações afetadas.

A infraestrutura usada pelo checkpoint automático não depende do framework de ML. Qualquer framework de ML é compatível com o checkpoint automático, desde que seja possível capturar o sinal SIGTERM e iniciar um processo de checkpoint.

No código do aplicativo, é necessário ativar os recursos do ponto de verificação automático fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar sinalizações de linha de comando ao iniciar o treinamento. Consulte o Guia de início rápido do checkpoint automático com o Pax. Nos bastidores, os frameworks salvam um checkpoint não programado quando um SIGTERM é recebido, e a VM da TPU afetada passa por manutenção quando a TPU não está mais em uso.

Guia de início rápido: checkpoint automático com MaxText

MaxText é um "LLM de alto desempenho, arbitrariamente escalonável, de código aberto e bem testado, escrito em Cloud TPUs puras do Python/JAX". MaxText contém toda a configuração necessária para usar o recurso Autocheckpoint.

O README do MaxText descreve duas maneiras de executar o MaxText em escala:

Ao usar multihost_runner.py, a única alteração necessária é definir a sinalização autocheckpoint-enabled ao provisionar o recurso na fila. Ao usar multihost_job.py, a única mudança necessária é especificar a sinalização de linha de comando ENABLE_AUTOCHECKPOINT=true ao iniciar o job.

Guia de início rápido: checkpoint automático com o Pax em fatias únicas

Nesta seção, fornecemos um exemplo de como configurar e usar o checkpoint automático com o Pax em uma única fração. Com a configuração adequada:

  • Um checkpoint será salvo quando ocorrer um evento de manutenção.
  • O Cloud TPU realizará a manutenção nas VMs da TPU afetadas depois que o checkpoint for salvo.
  • Quando o Cloud TPU concluir a manutenção, use a VM da TPU normalmente.
  1. Use a sinalização autocheckpoint-enabled ao criar a VM da TPU ou o recurso na fila.

    Exemplo:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Instalar o Pax em uma única fração

    O recurso Autocheckpoint funciona nas versões do Pax 1.1.0 ou posteriores. Nas VMs da TPU, instale jax[tpu] e os paxml mais recentes:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Iniciar o treinamento com a configuração adequada

    O exemplo a seguir mostra como configurar o modelo LmCloudSpmd2B para salvar checkpoints acionados pelo checkpoint em um bucket do Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Observe as duas sinalizações que são transmitidas ao comando:

    • jax_fully_async_checkpoint: com essa flag ativada, orbax.checkpoint.AsyncCheckpointer será usado. A classe AsyncCheckpointer salva automaticamente um checkpoint quando o script de treinamento recebe um sinal SIGTERM.
    • exit_after_ondemand_checkpoint: com essa sinalização ativada, o processo da TPU é encerrado depois que o checkpoint automático é salvo, o que aciona a manutenção a ser realizada imediatamente. Se você não usar essa sinalização, o treinamento continuará depois que o checkpoint for salvo, e o Cloud TPU aguardará um tempo limite (cinco minutos) antes de executar a manutenção necessária.

Guia de início rápido: checkpoint automático com Pax no Multislice

O checkpoint automático não funciona apenas para frações únicas, mas também para Multislice. Nesta seção, detalhamos as etapas necessárias para usar o checkpoint automático com o recurso Multislice.

  1. Especificar o checkpoint automático durante a criação de recursos na fila.

    Um ambiente Multislice só pode ser provisionado por meio de uma solicitação de recurso em fila. De forma semelhante ao caso de fatia única, use a sinalização autocheckpoint-enabled na chamada para criar um recurso na fila.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Consulte o Guia do usuário de multislice para ver detalhes sobre todas as opções disponíveis. Depois que a solicitação de recurso na fila for criada e no estado ACTIVE, siga as próximas etapas para executar o Pax com o Checkpoint automático.

  2. Instalar o Pax em todas as VMs no ambiente Multislice.

    Nas VMs de TPU, instale jax[tpu] e o paxml mais recente em todas as VMs de TPU no ambiente Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Iniciar o treinamento com a configuração adequada

    Neste exemplo, mostramos como configurar o modelo LmCloudSpmd2B para o checkpoint automático ao realizar o treinamento em um ambiente multislice. Antes de executar o script de treinamento, defina DCN_MESH_SHAPE como [2, 1, 1], conforme mostrado no código a seguir:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Ao iniciar o treinamento, além das sinalizações de linha de comando discutidas no caso de fatia única, são necessárias mais três:

    • num_hosts: o número total de hosts. Neste caso, é 2.
    • host_index: o índice do host que está iniciando o treinamento. Ela varia de 0 a N-1, em que N é o número total de hosts.
    • server_addr: o endereço IP do worker 0 do nó 0, com uma porta não utilizada (por exemplo, 8476). Para encontrar essas informações, use hostname -i no worker 0 do nó 0.

Checkpoint automático com o Orbax

O recurso "Autocheckpoint" não está limitado aos modelos MaxText ou Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de checkpoint funciona com a infraestrutura fornecida pelo Autocheckpoint. O Orbax, um namespace que fornece bibliotecas de utilitários comuns para usuários do JAX, oferece esses recursos.

Conforme explicado na documentação da Orbax, esses recursos são ativados por padrão para usuários de orbax.checkpoint.CheckpointManager. O método save chamado após cada etapa verifica automaticamente se um evento de manutenção está iminente e, em caso afirmativo, salva um checkpoint, mesmo que o número da etapa não seja um múltiplo de save_interval_steps. A documentação do GitHub também mostra como sair do treinamento depois de salvar um checkpoint automático, com uma modificação no código do usuário.