Checkpoint automático do Cloud TPU [pré-lançamento público]

Visão geral

Historicamente, quando uma VM TPU requer manutenção, o procedimento é iniciado imediatamente, sem deixar tempo para que os usuários realizem ações de preservação do progresso, como salvar um ponto de verificação. Isso é mostrado na Figura 1(a).

autocheckpoint

Figura 1. Ilustração do recurso de checkpoint automático: (a) Sem o checkpoint automático, o progresso do treinamento do último checkpoint é perdido quando há um evento de manutenção. (b) Com o ponto de verificação automático, o progresso do treinamento desde o último ponto de verificação pode ser preservado quando há um evento de manutenção.

Você pode usar o Autocheckpoint (Figura 1(b)) para preservar o progresso do treinamento, configurando o código para salvar um ponto de controle não programado quando um evento de manutenção ocorrer. Quando ocorre um evento de manutenção, o progresso desde o último é salvo automaticamente. O recurso funciona em fatias únicas e em Multislice.

O recurso Autocheckpoint funciona com frameworks que podem capturar SIGTERM e, posteriormente, salvar um checkpoint. Os frameworks com suporte incluem MaxText, Pax e JAX com Orbax. O suporte a outras frameworks será anunciado conforme forem disponibilizadas.

Por enquanto, apenas TPUs (v2-v4 e v5e) criados pela API Cloud TPU podem usar esse recurso. O suporte a TPUs no GKE será anunciado quando estiver disponível.

Como usar o checkpoint automático

A funcionalidade de verificação automática fica desativada por padrão. Ao criar um TPU ou um recurso enfileirado, é possível ativar essa opção adicionando a flag --autocheckpoint-enabled ao provisionar o TPU. Com o recurso ativado, o Cloud TPU executa as etapas a seguir quando recebe a notificação de um evento de manutenção:

  1. Capture o SIGTERM enviado ao processo usando o dispositivo TPU.
  2. Aguarda até que o processo seja encerrado ou cinco minutos se passem, o que acontecer primeiro, e realiza a manutenção nas fatias afetadas.

Observe que a infraestrutura usada pelo Autocheckpoint é independente do framework de ML. Qualquer framework de ML pode oferecer suporte ao Autocheckpoint, desde que capture o sinal SIGTERM e inicie um processo de checkpoint.

No código do aplicativo, é necessário ativar os recursos de verificação automática fornecidos pelo framework de ML. No Pax, por exemplo, isso significa ativar flags de linha de comando ao iniciar o treinamento. Consulte o Guia de início rápido do autocheckpoint com o Pax. Nos bastidores, as estruturas salvam um checkpoint não programado quando um SIGTERM é recebido e a VM da TPU afetada passa por manutenção quando a TPU não está mais disponível em uso.

Guia de início rápido: checkpoint automático com MaxText

MaxText é uma opção LLM arbitrariamente escalonável, de código aberto e bem testado, escrito em Python/JAX puro como alvo de Cloud TPUs". MaxText contém toda a configuração necessária para usar o Autocheckpoint. .

O README do MaxText descreve duas maneiras de executar o MaxText em escala:

Ao usar multihost_runner.py, a única mudança necessária é definir a flag autocheckpoint-enabled ao provisionar o recurso na fila. Ao usar multihost_job.py, a única mudança necessária é especificar o a sinalização de linha de comando ENABLE_AUTOCHECKPOINT=true ao iniciar o job.

Guia de início rápido: checkpoint automático com Pax em frações únicas

Nesta seção, mostramos um exemplo de como configurar e usar o Autocheckpoint com Pax em uma única fatia. Com a configuração apropriada:

  • Um checkpoint será salvo quando ocorrer um evento de manutenção.
  • O Cloud TPU fará a manutenção nas VMs de TPU afetadas após o checkpoint será salvo.
  • Quando a Cloud TPU concluir a manutenção, será possível usar a VM da TPU normalmente.
  1. Use a sinalização autocheckpoint-enabled ao criar a VM da TPU ou na fila.

    Exemplo:

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
  2. Instalar o Pax em uma única fração

    O recurso Autocheckpoint funciona na versão 1.1.0 ou mais recente do Pax. Nas VMs da TPU, Instale a jax[tpu] e a paxml mais recente:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Inicie o treinamento com a configuração apropriada

    O exemplo a seguir mostra como configurar o modelo LmCloudSpmd2B para salvar pontos de verificação acionados pelo Autocheckpoint em um bucket do Google Cloud Storage:

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Observe as duas sinalizações passadas para o comando:

    • jax_fully_async_checkpoint: com essa flag ativada, orbax.checkpoint.AsyncCheckpointer será usado. A classe AsyncCheckpointer salva automaticamente um ponto de controle quando o script de treinamento recebe um sinal SIGTERM.
    • exit_after_ondemand_checkpoint: com essa flag ativada, os processos da TPU são encerrados depois que o checkpoint automático é salvo, o que aciona a manutenção para ser realizada imediatamente. Se você não usar o treinamento vai continuar depois que o checkpoint for salvo e o Cloud TPU aguardará o tempo limite (cinco minutos) antes de fazer a manutenção necessária.

Guia de início rápido: checkpoint automático com Pax no Multislice

O ponto de verificação automático funciona não apenas para fatias únicas, mas também para várias fatias. Esta seção descreve as etapas necessárias para usar o Autocheckpoint com o Multislice.

  1. Especifique o ponto de verificação automático durante a criação de recursos na fila.

    Um ambiente com várias fatias só pode ser provisionado por uma solicitação de recurso em fila. Assim como no caso de fração única, use a sinalização autocheckpoint-enabled a chamada para criar um recurso na fila.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled

    Consulte o Guia do usuário do Multislice para detalhes sobre todas as opções disponíveis. Depois que a solicitação de recurso em fila for criada e estiver no estado ACTIVE, siga as próximas etapas para executar o Pax com o Autocheckpoint.

  2. Instale o Pax em todas as VMs no ambiente de várias fatias.

    Nas VMs da TPU, instale a jax[tpu] e a paxml mais recente em todas as VMs de TPU no ambiente Multislice:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  3. Inicie o treinamento com a configuração adequada

    Este exemplo mostra como configurar o modelo LmCloudSpmd2B para o ponto de verificação automático ao treinar em um ambiente multislice. Antes ao executar o script de treinamento, defina DCN_MESH_SHAPE como [2, 1, 1], conforme mostrado o seguinte código:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]

    Ao iniciar o treinamento, além das flags de linha de comando discutidas no caso de fatia única, três outras são necessárias:

    • num_hosts: o número total de hosts. Nesse caso, é 2.
    • host_index: o índice do host que inicia o treinamento. Ela varia de 0 a N-1, sendo N o número total de hosts.
    • server_addr: o endereço IP do worker 0 do nó 0, com um porta (por exemplo, 8476). Para encontrar essas informações, use hostname -i no worker 0 do nó 0.

Checkpoint automático com o Orbax

O recurso de checkpoint automático não está limitado a MaxText ou Pax. Qualquer framework que possa capturar o sinal SIGTERM e iniciar um processo de verificação funciona com a infraestrutura fornecida pelo Autocheckpoint. Orbax, um namespace que fornece bibliotecas de utilitários comuns para usuários do JAX oferecem esses recursos.

Conforme explicado na documentação do Orbex, esses recursos são ativados por padrão para os usuários do orbax.checkpoint.CheckpointManager. Método save chamado após cada etapa verifica automaticamente se uma operação evento for iminente e, em caso afirmativo, salva um checkpoint, mesmo que o número da etapa não é um múltiplo de save_interval_steps. A documentação do GitHub (link em inglês) também ilustra como fazer com que o treinamento saia após salvar um Autocheckpoint, com uma modificação no código do usuário.