Cloud TPU 자동 체크포인트 [공개 미리보기]

개요

지금까지 TPU VM에 유지보수가 필요할 때는 사용자가 체크포인트 저장과 같은 진행 상태 보존 작업을 수행할 시간 없이 절차가 즉시 시작되었습니다. 이 내용은 그림 1(a)에 나와 있습니다.

자동 체크포인트

그림 1. 자동 체크포인트 기능 그림: (a) 자동 체크포인트가 없으면 예정된 유지보수 이벤트가 있을 때 마지막 체크포인트의 학습 진행 상태가 손실됩니다. (b) 자동 체크포인트를 사용하면 예정된 유지보수 이벤트가 있을 때 마지막 체크포인트 이후의 학습 진행 상태를 보존할 수 있습니다.

자동 체크포인트(그림 1(b))를 사용하면 유지보수 이벤트가 발생할 때 예약되지 않은 체크포인트를 저장하도록 코드를 구성하여 학습 진행 상태를 보존할 수 있습니다. 유지보수 이벤트가 발생하면 마지막 체크포인트 이후의 진행 상태가 자동으로 저장됩니다. 이 기능은 단일 슬라이스와 멀티슬라이스 모두에서 작동합니다.

자동 체크포인트 기능은 SIGTERM을 캡처하고 이후에 체크포인트를 저장하는 프레임워크에서 작동합니다. 지원되는 프레임워크에는 MaxText, Pax, JAX(Orbax 사용)가 포함됩니다. 추가 프레임워크 지원은 제공되는 대로 공지됩니다.

현재는 Cloud TPU API를 통해 생성된 TPU(v2-v4 및 v5e)만 이 기능을 사용할 수 있습니다. GKE의 TPU 지원은 제공되는 대로 공지됩니다.

자동 체크포인트 사용

자동 체크포인트 기능은 기본적으로 사용 중지되어 있습니다. TPU 또는 큐에 추가된 리소스를 만들 때는 TPU를 프로비저닝할 때 --autocheckpoint-enabled 플래그를 추가하여 사용 설정할 수 있습니다. 이 기능을 사용 설정하면 유지보수 이벤트 알림이 수신되었을 때 Cloud TPU가 다음 단계를 수행합니다.

  1. TPU 기기를 사용해서 전송된 SIGTERM을 진행 상태에 캡처합니다.
  2. 프로세스가 종료되거나 5분이 경과될 때까지 기다리고 해당 슬라이스에 유지보수를 수행합니다.

자동 체크포인트에 사용되는 인프라는 ML 프레임워크에 독립적입니다. SIGTERM 신호를 캡처하고 체크포인트 지정 프로세스를 시작할 수 있는 한 어떤 ML 프레임워크라도 자동 체크포인트를 지원할 수 있습니다.

애플리케이션 코드에서 ML 프레임워크에서 제공된 자동 체크포인트 기능을 사용 설정해야 합니다. 예를 들어 Pax에서는 학습을 시작할 때 명령줄 플래그를 사용 설정해야 합니다(Pax에서 자동 체크포인트 빠른 시작 참조). 이 과정 중에 프레임워크는 SIGTERM이 수신될 때 예약되지 않은 체크포인트를 저장하고 TPU가 더 이상 사용 중이 아닐 때 영향을 받는 TPU VM에 유지보수가 진행됩니다.

빠른 시작: MaxText를 사용한 자동 체크포인트

MaxText는 "Cloud TPU를 대상으로 순수 Python/JAX로 작성되어 임의로 확장 가능하고 잘 테스트된 고성능 오픈소스 LLM"입니다. MaxText에는 자동 체크포인트 기능을 사용하는 데 필요한 모든 설정이 포함됩니다.

MaxText README에서는 규모에 맞게 MaxText를 실행하기 위한 두 가지 방법에 대해 설명합니다.

multihost_runner.py 사용 시 필요한 유일한 변경사항은 큐에 추가된 리소스를 프로비저닝할 때 autocheckpoint-enabled 플래그를 설정하는 것입니다. multihost_job.py 사용 시 필요한 유일한 변경사항은 작업을 시작할 때 ENABLE_AUTOCHECKPOINT=true 명령줄 플래그를 지정하는 것입니다.

빠른 시작: 단일 슬라이스에서 Pax를 사용한 자동 체크포인트

이 섹션에서는 단일 슬라이스에서 Pax와 함께 자동 체크포인트를 설정하고 사용하는 방법에 대한 예시를 제공합니다. 다음과 같이 되도록 적절한 설정을 사용합니다.

  • 유지보수 이벤트가 발생할 때 체크포인트가 저장됩니다.
  • 체크포인트 저장 후 영향을 받는 TPU VM에서 Cloud TPU가 유지보수를 수행합니다.
  • Cloud TPU가 유지보수를 완료하면 일반적으로 TPU VM을 사용할 수 있습니다.
  1. TPU VM 또는 큐에 추가된 리소스를 만들 때 autocheckpoint-enabled 플래그를 사용합니다.

    예를 들면 다음과 같습니다.

    PROJECT=your-gcp-project-name
    ZONE=zone-you-want-to-use
    NODE_ID=your-node-id
    ACCELERATOR_TYPE=your-accelerator-type
    gcloud config set project $PROJECT
    gcloud config set compute/zone $ZONE
    
    gcloud alpha compute tpus tpu-vm create $NODE_ID \
    --accelerator-type $ACCELERATOR_TYPE \
    --version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    
  2. 단일 슬라이스에 Pax 설치

    자동 체크포인트 기능은 Pax 버전 1.1.0 이상에서 작동합니다. TPU VM에서 jax[tpu] 및 최신 paxml을 설치합니다.

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. 적절한 구성으로 학습 시작

    다음 예시에서는 자동 체크포인트로 트리거된 체크포인트를 Google Cloud Storage 버킷에 저장하도록 LmCloudSpmd2B 모델을 구성하는 방법을 보여줍니다.

    JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py
    --jax_fully_async_checkpoint=1 \
    --exit_after_ondemand_checkpoint=1 \
    --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
    --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
    

    명령어에 전달되는 다음 두 가지 플래그에 주의하세요.

    • jax_fully_async_checkpoint: 이 플래그를 설정하면 orbax.checkpoint.AsyncCheckpointer가 사용됩니다. AsyncCheckpointer 클래스는 학습 스크립트에 SIGTERM 신호가 수신될 때 자동으로 체크포인트를 저장합니다.
    • exit_after_ondemand_checkpoint: 이 플래그를 설정하면 자동 체크포인트가 성공적으로 저장된 후 TPU 프로세스가 종료되고, 유지보수가 즉시 수행되도록 트리거됩니다. 이 플래그를 사용하지 않으면 체크포인트가 저장된 후에도 학습이 계속되고 Cloud TPU가 시간 초과(5분)가 발생할 때까지 기다린 후에 필요한 유지보수를 수행합니다.

빠른 시작: 멀티슬라이스에서 Pax를 사용한 자동 체크포인트

자동 체크포인트는 단일 슬라이스뿐만 아니라 멀티슬라이스에도 작동합니다. 이 섹션에서는 멀티슬라이스에서 자동 체크포인트를 사용하는 데 필요한 단계에 대해 자세히 설명합니다.

  1. 큐에 추가된 리소스를 만드는 동안 자동 체크포인트를 지정합니다.

    멀티슬라이스 환경은 큐에 추가된 리소스 요청을 통해서만 프로비저닝할 수 있습니다. 단일 슬라이스 사례와 비슷하게 큐에 추가된 리소스를 만들기 위한 호출에서 autocheckpoint-enabled 플래그를 사용합니다.

    QR_ID=your-qr-id
    NODE_COUNT=your-node-count
    ACCELERATOR_TYPE=your-accelerator-type
    
    gcloud alpha compute tpus queued-resources create $QR_ID \
    --node-count $NODE_COUNT \
    --accelerator-type $ACCELERATOR_TYPE \
    --runtime-version tpu-ubuntu2204-base \
    --autocheckpoint-enabled
    

    사용 가능한 모든 옵션에 대한 자세한 내용은 멀티슬라이스 사용자 가이드를 참조하세요. 큐에 추가된 리소스 요청이 생성되고 ACTIVE 상태가 되면 다음 단계에 따라 자동 체크포인트 기능을 사용해서 Pax를 실행합니다.

  2. 멀티슬라이스 환경에서 모든 VM에 Pax를 설치합니다.

    TPU VM에서 jax[tpu] 및 최신 paxml을 멀티슬라이스 환경의 모든 TPU VM에 설치합니다.

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  3. 적절한 구성으로 학습 시작

    이 예시에서는 멀티슬라이스 환경에서 학습을 수행할 때 자동 체크포인트를 위해 LmCloudSpmd2B 모델을 구성하는 방법을 보여줍니다. 학습 스크립트를 실행하기 전에 다음 코드에 표시된 것처럼 DCN_MESH_SHAPE를 [2, 1, 1]로 설정합니다.

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
    """SPMD model with 2B params.
    
    Global batch size = 2 * 2 * 1 * 32 = 128
    """
    PERCORE_BATCH_SIZE = 8
    
    NUM_LAYERS = 18
    MODEL_DIMS = 3072
    HIDDEN_DIMS = MODEL_DIMS * 4
    
    CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
    ICI_MESH_SHAPE = [1, 4, 1]
    DCN_MESH_SHAPE = [2, 1, 1]
    

    단일 슬라이스 사례에서 설명한 명령줄 플래그 외에도 학습을 시작할 때 다음 3개 항목이 더 필요합니다.

    • num_hosts: 총 호스트 수입니다. 여기에서는 2입니다.
    • host_index: 학습을 시작하는 호스트의 색인입니다. 0부터 N-1까지 다양하며, 여기서 N은 총 호스트 수입니다.
    • server_addr: 사용되지 않은 포트(예: 8476)와 함께 노드 0의 작업자 0의 IP 주소입니다. 이 정보를 찾으려면 노드 0의 작업자 0에서 hostname -i를 사용합니다.

Orbax에서의 자동 체크포인트

자동 체크포인트 기능은 MaxText 또는 Pax로 제한되지 않습니다. SIGTERM 신호를 캡처하고 체크포인트 지정 프로세스를 시작할 수 있는 모든 프레임워크가 자동 체크포인트로 제공되는 인프라를 지원합니다. JAX 사용자를 위한 일반적인 유틸리티 라이브러리를 제공하는 네임스페이스인 Orbax에서도 이러한 기능이 제공됩니다.

Orbax 문서에 설명된 대로 이러한 기능은 orbax.checkpoint.CheckpointManager 사용자에게 기본적으로 사용 설정되어 있습니다. 모든 단계에서 유지보수 이벤트가 임박했는지 여부를 자동으로 확인한 후 호출되는 save 메서드는 단계 번호가 save_interval_steps의 배수가 아니더라도 체크포인트를 저장합니다. 또한 GitHub 문서에서는 사용자 코드 수정과 함께 자동 체크포인트 저장 후 학습을 종료하는 방법을 설명합니다.