Checkpoint automático do Cloud TPU [pré-lançamento público]

Visão geral

Historicamente, quando uma VM da TPU precisa de manutenção, o procedimento é iniciado imediatamente, sem que os usuários tenham tempo de realizar ações que preservam o progresso, como salvar um checkpoint. Isso é mostrado na Figura 1(a).

checkpoint automático

Figura 1. Ilustração do recurso de checkpoint: (a) Sem o checkpoint, o progresso do treinamento do último checkpoint é perdido quando há um evento de manutenção próximo. (b) Com o checkpoint, o progresso do treinamento desde o último checkpoint pode ser preservado quando há um evento de manutenção próximo.

Use o Autocheckpoint (Figura 1(b)) para preservar o progresso do treinamento configurando seu código para salvar um checkpoint não programado quando ocorrer um evento de manutenção. Quando ocorre um evento de manutenção, o progresso desde o último checkpoint é salvo automaticamente. O recurso funciona em frações únicas e em multislices.

O recurso "Autocheckpoint" funciona com frameworks que podem capturar o SIGTERM e, depois, salvar um checkpoint. Os frameworks compatíveis incluem MaxText, Pax e JAX com Orbax. O suporte para estruturas adicionais será anunciado conforme forem disponibilizados.

Por enquanto, apenas TPUs (v2-v4 e v5e) criadas por meio da API Cloud TPU podem usar esse recurso. O suporte para TPUs no GKE será anunciado quando estiver disponível.

Como usar o checkpoint automático

A funcionalidade do checkpoint automático está desativada por padrão. Ao criar uma TPU ou um recurso na fila, é possível ativá-lo adicionando a sinalização --autocheckpoint-enabled ao provisionar a TPU. Com o recurso ativado, o Cloud TPU executa as seguintes etapas ao receber uma notificação de um evento de manutenção:

  1. Capture o SIGTERM enviado ao processo usando o dispositivo TPU.
  2. Espera até que o processo seja encerrado ou cinco minutos tenham se passado, o que acontecer primeiro, e realiza a manutenção nas frações afetadas.

Observe que a infraestrutura usada pelo Autocheckpoint é independente do framework de ML. Qualquer framework de ML oferece suporte ao checkpoint automático, desde que ele possa capturar o sinal SIGTERM e iniciar um processo de checkpoint.

No código do aplicativo, você precisa ativar os recursos do ponto de verificação automático fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar as sinalizações de linha de comando ao iniciar o treinamento. Consulte o Guia de início rápido do checkpoint automático com o Pax. Em segundo plano, os frameworks salvam um checkpoint não programado quando um SIGTERM é recebido, e a VM da TPU afetada passa por manutenção quando a TPU não está mais em uso.

Guia de início rápido: checkpoint automático com MaxText

O MaxText é um LLM bem testado, arbitrariamente escalonável, de alto desempenho e de código aberto, escrito em Python/JAX puro voltado para Cloud TPUs. MaxText contém toda a configuração necessária para usar o recurso Autocheckpoint.

O README do MaxText descreve duas maneiras de executar o MaxText em escala:

Ao usar multihost_runner.py, a única alteração necessária é definir a sinalização autocheckpoint-enabled ao provisionar o recurso na fila. Ao usar multihost_job.py, a única mudança necessária é especificar a sinalização de linha de comando ENABLE_AUTOCHECKPOINT=true ao iniciar o job.

Guia de início rápido: checkpoint automático com Pax em frações únicas

Nesta seção, fornecemos um exemplo de como configurar e usar o Autocheckpoint com o Pax em uma única fração. Com a configuração apropriada:

  • Um checkpoint será salvo quando ocorrer um evento de manutenção.
  • O Cloud TPU fará a manutenção nas VMs da TPU afetadas depois que o checkpoint for salvo.
  • Quando a Cloud TPU concluir a manutenção, será possível usar a VM da TPU normalmente.
  1. Use a sinalização autocheckpoint-enabled ao criar a VM da TPU ou o recurso na fila.

    Exemplo:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. Instalar o Pax em uma única fração

    O recurso Autocheckpoint funciona na versão 1.1.0 ou mais recente do Pax. Nas VMs da TPU, instale jax[tpu] e o paxml mais recente:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Inicie o treinamento com a configuração apropriada

    O exemplo a seguir mostra como configurar o modelo LmCloudSpmd2B para salvar os checkpoints acionados pelo Autocheckpoint em um bucket do Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    Observe as duas sinalizações passadas para o comando:

    • jax_fully_async_checkpoint: com essa sinalização ativada, orbax.checkpoint.AsyncCheckpointer será usado. A classe AsyncCheckpointer salva automaticamente um checkpoint quando o script de treinamento recebe um sinal SIGTERM.
    • exit_after_ondemand_checkpoint: com essa sinalização ativada, os processos da TPU são encerrados depois que o checkpoint automático é salvo, o que aciona a manutenção a ser executada imediatamente. Se você não usar essa sinalização, o treinamento continuará depois que o checkpoint for salvo, e o Cloud TPU aguardará o tempo limite (cinco minutos) antes de realizar a manutenção necessária.

Guia de início rápido: checkpoint automático com Pax no Multislice

O checkpoint automático funciona não apenas para frações únicas, mas também para Multislice. Esta seção detalha as etapas necessárias para usar o Autocheckpoint com o Multislice.

  1. Especifique o Autocheckpoint durante a criação de recursos na fila.

    Um ambiente Multislice só pode ser provisionado por meio de uma solicitação de recurso na fila. Semelhante ao caso de fração única, use a sinalização autocheckpoint-enabled na chamada para criar um recurso na fila.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    Consulte o Guia do usuário do Multislice para ver detalhes sobre todas as opções disponíveis. Depois que a solicitação de recurso na fila for criada e estiver no estado ACTIVE, siga as próximas etapas para executar o Pax com o Autocheckpoint.

  2. Instale o Pax em todas as VMs no ambiente Multislice.

    Nas VMs da TPU, instale jax[tpu] e o paxml mais recente em todas as VMs da TPU no ambiente do Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. Inicie o treinamento com a configuração apropriada

    Neste exemplo, mostramos como configurar o modelo LmCloudSpmd2B para checkpoint automático ao treinar em um ambiente multislice. Antes de executar o script de treinamento, defina DCN_MESH_SHAPE como [2, 1, 1], conforme mostrado no seguinte código:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    Ao iniciar o treinamento, além das sinalizações de linha de comando discutidas no caso de fração única, são necessárias mais três:

    • num_hosts: o número total de hosts. Nesse caso, é 2.
    • host_index: o índice do host que inicia o treinamento. Ele varia de 0 a N-1, em que N é o número total de hosts.
    • server_addr: o endereço IP do worker 0 do nó 0, com uma porta não utilizada (por exemplo, 8476). Para encontrar essas informações, use hostname -i no worker 0 do nó 0.

Checkpoint automático com o Orbax

O recurso de checkpoint automático não está limitado a MaxText ou Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de checkpoint funciona com a infraestrutura fornecida pelo Autocheckpoint. O Orbax, um namespace que fornece bibliotecas de utilitários comuns para usuários do JAX, oferece esses recursos.

Conforme explicado na documentação da Orbax, esses recursos são ativados por padrão para usuários do orbax.checkpoint.CheckpointManager. O método save, chamado após cada etapa, verifica automaticamente se um evento de manutenção é iminente e, em caso afirmativo, salva um checkpoint, mesmo que o número da etapa não seja um múltiplo de save_interval_steps. A documentação do GitHub também ilustra como fazer o treinamento sair depois de salvar um checkpoint automático, com uma modificação no código do usuário.