使用多层级检查点在 GKE 上训练大规模机器学习模型


本页面介绍了如何在 GKE 上训练机器学习模型期间使用多层检查点来可靠地存储和管理检查点。对于使用数千个以上节点的大规模训练作业,检查点存储和管理至关重要。这些大规模作业经常中断(可能每小时中断一次),并且从中断中恢复可能需要很长时间。

优势

使用多层级检查点具有以下优势:

  • 完全编排的检查点数据管理,包括备份、复制和自动恢复,适用于以下工作负载:
  • 从本地节点中存储的检查点快速恢复训练作业。您还可以使用存储在训练集群中另一个节点中的检查点进行恢复。
  • 在最糟糕的情况下(即没有集群内检查点),可从 Cloud Storage 备份中存储的检查点快速恢复训练作业。

准备工作

在开始之前,请确保您已执行以下任务:

  • 启用 Google Kubernetes Engine API。
  • 启用 Google Kubernetes Engine API
  • 如果您要使用 Google Cloud CLI 执行此任务,请安装初始化 gcloud CLI。 如果您之前安装了 gcloud CLI,请运行 gcloud components update 以获取最新版本。

要求

多层检查点需要 GKE 集群版本 1.32.4-gke.1415000 或更高版本。

限制

  • 不支持 Autopilot 集群。

配置 GKE 节点以使用多层检查点

本部分介绍了如何在新的集群和现有集群中配置 GKE 节点。

在新集群上配置节点

  1. 创建一个启用了多层检查点、Cloud Storage FUSE CSI 驱动程序Workload Identity Federation for GKE 的集群。如果您使用 TPU 切片来处理机器学习工作负载,则需要调整集群创建命令,以纳入 TPU 切片节点池的配置。

    gcloud container clusters create CLUSTER_NAME \
        --addons=HighScaleCheckpointing,GcsFuseCsiDriver  \
        --node-locations=NODE_LOCATION \
        --workload-pool=PROJECT_ID.svc.id.goog \
        --cluster-version=CLUSTER_VERSION
        --location=CLUSTER_LOCATION \
        --machine-type=MACHINE_TYPE \
        --num-nodes=NUM_NODES
    

    替换以下值:

    • CLUSTER_NAME:集群的名称。
    • NODE_LOCATION:集群节点的可用区。此处显示您的 TPU 容量。
    • PROJECT_ID:您的 Google Cloud 项目 ID
    • CLUSTER_VERSION:集群的版本。1.32.4-gke.1415000 是支持的最低版本。
    • CLUSTER_LOCATION:要在其中创建集群的区域。
    • MACHINE_TYPE:用于运行 JobSet 控制器和多层级检查点控制器等组件的节点的机器类型。对于大规模训练,我们建议使用至少 e2-standard-4 台机器。您不会使用此机器类型进行模型训练;相反,您将为此目的创建单独的节点池,通常会利用加速器优化的虚拟机系列。
    • NUM_NODES:要在每个集群的可用区中创建的节点数。

配置现有集群中的节点

如需在现有集群中使用多层检查点,请使用以下命令同时启用该功能以及 Cloud Storage FUSE CSI 驱动程序Workload Identity Federation for GKE。现有集群版本必须高于 1.32.3-gke.1170000。

  1. 启用 Workload Identity Federation for GKE

    gcloud container clusters update CLUSTER_NAME \
        --workload-pool=PROJECT_ID.svc.id.goog \
        --location=CLUSTER_LOCATION
    

    替换以下值:

    • CLUSTER_NAME:集群的名称。
    • PROJECT_ID:您的 Google Cloud 项目 ID
    • CLUSTER_LOCATION:集群的区域。
  2. 启用多层级检查点和 Cloud Storage FUSE CSI 驱动程序:

    gcloud container clusters update CLUSTER_NAME \
        --update-addons=HighScaleCheckpointing=ENABLED,GcsFuseCsiDriver=ENABLED \
        --location=CLUSTER_LOCATION
    

配置使用多层级检查点设置的权限

本部分介绍如何配置使用多层级检查点的权限。

授予对 Cloud Storage 存储分区的访问权限

多层级检查点设置 CSI 驱动程序使用的临时卷必须使用现有的 Cloud Storage 存储分区。

如需将检查点存储在 Cloud Storage 存储桶中,多层检查点需要访问该存储桶。向 Kubernetes 服务账号授予存储桶的 Storage Object User (roles/storage.objectUser) IAM 角色,以实现多层级检查点。

gcloud storage buckets add-iam-policy-binding gs://GCS_BUCKET \
    --member "principal://iam.googleapis.com/projects/PROJECT_NUMBER/locations/global/workloadIdentityPools/PROJECT_ID.svc.id.goog/subject/ns/gke-managed-checkpointing/sa/gke-checkpointing-multitier-node" \
    --role "roles/storage.objectUser"

替换以下值:

(可选)授予 Compute Engine 默认服务账号访问权限

如果您的 Compute Engine 实例需要对 Cloud Storage 存储桶的读取权限,请向 Compute Engine 默认服务账号授予 Storage Object Viewer (roles/storage.objectViewer) IAM 角色。

gcloud storage buckets add-iam-policy-binding gs://GCS_BUCKET \
    --member serviceAccount:PROJECT_NUMBER-compute@developer.gserviceaccount.com \
    --role roles/storage.objectViewer

部署 JobSet 控制器

JobSet 控制器负责管理在 GKE 上运行模型训练的批处理作业,并调整其资源分配以高效处理工作负载。确保训练作业启动器部署并使用 JobSet。

如需将 JobSet 部署中管理器容器的内存请求增加到 1 Gi、内存限制增加到 2 Gi,并将 CPU 请求增加到 1,请运行以下补丁命令:

kubectl patch -n jobset-system deploy jobset-controller-manager --type json \
    --patch '[{"op": "add", "path": "/spec/template/spec/containers/0/resources", "value": {"limits": {"memory": "2Gi"}, "requests": {"cpu": "1", "memory": "1Gi"}}}]'

初始化多层级检查点设置 CSI 驱动程序

本部分介绍了如何在工作负载将要运行的节点上初始化多层检查点设置 CSI 驱动程序。CSI 驱动程序负责在模型训练过程中处理检查点的存储和管理。

创建 CheckpointConfiguration

CheckpointConfiguration 是一种 Kubernetes 自定义资源,用于指定部署多层检查点 CSI 驱动程序的属性。此资源是集群级资源。

  1. 创建以下 checkpoint.yaml 清单。

    apiVersion: checkpointing.gke.io/v1
    kind: CheckpointConfiguration
    metadata:
      name: MTC_CONFIG_NAME-configuration
    spec:
        cloudStorageBucketName: GCS_BUCKET
        nodeSelector:
            node.kubernetes.io/instance-type: MACHINE_TYPE
        tolerations:
        - key: TOLERATION_KEY
            operator: Exists
            effect: NoSchedule
        inMemoryVolumeSize: IN_MEMORY_VOLUME_SIZE
        gcsFuseMountOptions:
        - implicit-dirs
        - metadata-cache:negative-ttl-secs:0
        - metadata-cache:ttl-secs:-1
        - metadata-cache:stat-cache-max-size-mb:-1
        - metadata-cache:type-cache-max-size-mb:-1
        - file-cache:max-size-mb:-1
        - file-cache:cache-file-for-range-read:true
        - file-system:kernel-list-cache-ttl-secs:0
        - file-cache:enable-parallel-downloads:true
        - read_ahead_kb=1024
        - write:enable-streaming-writes:true
    

    替换以下内容:

    • MTC_CONFIG_NAME:CheckpointConfiguration 的名称。 此名称对于集群是全局的,不特定于某个作业。
    • GCS_BUCKET:您将存储检查点数据的 Cloud Storage 存储桶的名称。使用您在设置具有权限的 Cloud Storage 存储桶步骤中设置的存储桶。
    • MACHINE_TYPE:相应加速器的机器类型。可以是以下值之一:

      如需详细了解如何在 GKE 中使用 GPU 运行分布式工作负载,请参阅运行多实例 GPU。对于 TPU,请参阅创建 TPU 切片节点池

    • TOLERATION_KEY:此字段允许将 CSI 驱动程序调度到具有匹配污点的节点上。如需详细了解污点在不同加速器类型上的运作方式,请参阅以下页面:

    • IN_MEMORY_VOLUME_SIZE:内存中检查点缓存的大小。指定数量和单位(例如,200 Gi)。此值应为:

      • TPU 的本地检查点大小乘以 2.2
      • 具有单个对等方的 GPU 的本地检查点大小乘以 6.6。
  2. 应用清单:

    kubectl apply -f checkpoint.yaml
    
  3. 检查 CSI 驱动程序是否正在运行:

    kubectl get pod -n gke-managed-checkpointing
    

    输出应类似如下所示:将有多个条目,每个加速节点对应一个条目。

    NAME                                                          READY   STATUS    RESTARTS   AGE
    multitier-driver-e2b033a7-a4e7-496a-87a3-ffd7fcc2e57b-2d4fz   5/5     Running   0          114s
    

卸载多层检查点设置 CSI 驱动程序

如果您想取消部署多层检查点设置 CSI 驱动程序,请删除 CheckpointConfiguration 资源。多层级检查点控制器会从节点中移除 CSI 驱动程序。这样会移除 RAM 磁盘,并释放内存以用于其他工作负载。例如:

kubectl delete -f checkpoint.yaml

管理 Cloud Storage 备份的数据保留和垃圾回收

您负责为检查点的 Cloud Storage 备份实现保留政策。多层级检查点仅将检查点备份写入 Cloud Storage,绝不会修改或删除这些备份。

许多开源工具都可以处理保留和垃圾收集,包括:

以下示例使用 backup-warden,其中 backup 目录已装载到使用 Cloud Storage FUSE 的备份位置:

# Add --delete option to actually delete the backups, as is it only shows what would be deleted (dry-run)
backup-warden -p backup \
    --hourly 24 \
    --daily 7 \
    --weekly 5 \
    --monthly always \
    --yearly always \
    --prefer-recent

更新工作负载 JobSet 清单

更新作业的 JobSet 清单,以纳入大规模检查点卷。具体细节取决于您的工作负载。

例如,如需扩展在 GKE 中部署 TPU 多片中的示例 JobSet,请执行以下步骤:

  1. 将以下代码行添加到 jax-tpu 容器中。

    volumeMounts:
    - name: checkpoint
      mountPath: CHECKPOINT_DIR
    

    CHECKPOINT_DIR 替换为您的检查点目录的路径。 这是生成 replicator.yaml 的位置,也是多层级检查点执行检查点保存操作的位置。如需了解详情,请参阅在应用中集成多层级检查点

  2. 将以下代码行添加到作业规范的 spec.template.spec 字段中。

    volumes:
    - name: checkpoint
      csi:
        driver: multitier-checkpoint.csi.storage.gke.io
    

在应用中集成多层级检查点

如需分享有关检查点位置和复制就绪情况的信息,请修改您的应用以使用以下协议与多层检查点通信。

启动

本部分介绍了应用需要执行的初始步骤,以便与多层检查点机制进行交互。

复制器是多层级检查点设置的核心组件,作为 CSI 驱动程序的一部分在每个节点上运行。复制器负责管理存储层之间的检查点复制,从本地 RAM 磁盘到对等节点,再到 Cloud Storage 等外部存储。

replicator.yaml 文件充当机器学习训练作业(框架代码)与 Replicator 组件之间的动态控制平面。您的机器学习应用会在本地卷 (RAMDisk) 上以编程方式生成此文件,训练作业和 Replicator 服务都可以访问该文件。此清单允许机器学习框架向复制器提供运行时配置和生命周期管理指令,这些指令不同于在后端设置期间定义的静态基础架构参数(例如,Cloud Storage 上传频率)。

如需查看此互动的具体示例,请参阅:

您的应用应在启动期间执行以下步骤:

  1. 等待 replicator.yaml 文件消失,这表示 Replicator 已准备好由您的应用进行配置。replicator.yaml 文件是在更新工作负载 JobSet 清单部分中配置的 CHECKPOINT_DIR 位置生成的。

    首次创建模型训练作业时,replicator.yaml 文件尚不存在,您的应用可以立即继续运行。不过,如果作业已重启(例如,由于发生故障或人工干预),系统可能仍在处理之前的作业实例,并且该实例的 replicator.yaml 可能仍存在于本地卷上。

  2. 您的应用或机器学习作业会创建 replicator.yaml 文件,其中包含类似于以下内容的配置。

    Orbax

    job-name: orbax
    framework: orbax
    assume-data-parallelism: 3
    node-rank: 0
    nodes: 32
    peer-ranks: [1, 16] or peers-per-node: 2
    backup-interval-minutes: 30
    

    PyTorch

    job-name: nemo
    framework: pytorch.distributed
    node-rank: 0
    nodes: 32
    peer-ranks: [1, 16] or peers-per-node: 2
    backup-interval-minutes: 30
    

    此示例配置包含以下字段:

    • name:训练作业的名称。
    • framework:训练作业使用的机器学习框架。
    • node-rank:分布式训练作业中当前节点的唯一标识符。这表示创建相应文件的节点的节点排名。参与运行的每个节点都有自己的排名。
    • nodes:参与分布式训练作业的节点总数。 此值来自 Pod 的元数据。机器学习训练作业也可以查看此值。
    • peer-rankspeers-per-node:指定复制拓扑的两种替代方式。这两个参数只能提供其中一个。
      • peer-ranks:当前节点的检查点数据应复制到的对等节点的显式排名。这样可以精细控制哪些特定节点充当复制伙伴。
      • peers-per-node:复制器应自动选择用于复制的每个节点的对等节点数。
    • backup-interval-minutes:以分钟为单位,检查点备份到 Cloud Storage 的频率。建议您将此值设置为 30 分钟或更长时间。
  3. 等待系统删除新的 replicator.yaml 文件。这表示复制器已重新启动并执行了清理。此步骤可让您在应用执行下一部分中的步骤时,避免本地卷上出现任何过时或临时文件。

从上次已知良好 (LKG) 检查点恢复

  1. 在复制器初始化后,多层级检查点会为每个 TPU 或 GPU 工作器创建一个符号链接。这些符号链接是在与 replicator.yaml 文件相同的已装载本地卷中创建的,作业会将检查点保存在该卷中。

    符号链接的格式为 <job-name>-s{step}-n<node-rank>-w<worker-index>.restore

  2. 从相应的 .restore 文件恢复每个工作器。有关示例,请参阅下一部分中的 Orbax 复制的检查点管理器示例。

保存检查点

在训练作业进行期间,您的应用会多次执行这些步骤。 保存操作会在您在更新工作负载 JobSet 清单中配置的 CHECKPOINT_DIR 位置进行。

Orbax

创建 Orbax 检查点。目录以步骤编号命名。复制器会检测新创建的检查点目录,根据需要执行复制或备份,并自动清理。

如需详细了解如何使用 Orbax 复制器检查点管理器,请参阅 MaxtTest checkpointing 文件。 如需查看复制器服务互动的示例,请参阅 MaxText max_utils 文件

PyTorch

使用 InClusterLocalCheckpointIO 作为自定义 pytorch_lightning.CheckpointIO,以通过本地存储实现正确的分布式检查点。以下示例命令使用基于 NVIDIA NeMo 框架构建的参考实现来启用多层级检查点:

torchrun train.py <other_train_flags> \
    --local-ckpt-dir=CHECKPOINT_DIR \
    --local-ckpt-interval=20 \
    --job-name=JOB_NAME \
    --enable-high-scale-ckpt

替换以下内容:

  • CHECKPOINT_DIR:检查点目录的路径。
  • JOB_NAME:训练作业工作负载的名称。

集群升级

对于集群升级,您可以在升级之前或之后删除并重新创建 CheckpointConfiguration 对象。此操作是必需的,因为此对象动态部署的节点检查点驱动程序 DaemonSet 不会自动升级。

如果您特别希望保持 DaemonSet 规范不变,则无需执行任何操作。

问题排查

本部分提供了问题排查指南,以解决多层级检查点相关问题。 如需了解常规存储问题排查,请参阅排查 GKE 中的 Cloud Storage 问题

未启用多层级检查点

以下错误表明您的集群上未启用多层级检查点:

error: unable to recognize "checkpoint.yaml": no matches for kind "CheckpointConfiguration" in version "checkpointing.gke.io/v1"

创建 CheckpointConfiguration 步骤中运行 kubectl apply -f checkpoint.yaml 后,您可能会遇到此错误。

如需解决此问题,请使用以下命令检查是否已在集群上启用多层检查点:

gcloud container clusters describe CLUSTER_NAME \
    --project PROJECT_ID
    --location CLUSTER_LOCATION

如果启用了多层级检查点,输出应类似如下所示:

addonsConfig:
  gcePersistentDiskCsiDriverConfig:
    enabled: true
  gcsFuseCsiDriverConfig:
    enabled: true
  highScaleCheckpointingConfig:
    enabled: true
  kubernetesDashboard:
    disabled: true
  networkPolicyConfig:
    disabled: true

如果未启用多层级检查点,请更新集群以启用多层级检查点

多层级检查点设置 CSI 驱动程序无法装载卷

如果 CSI 驱动程序无法挂载 Cloud Storage 卷,您可能会遇到此问题。可能有多行类似的内容。

kubectl get pod -n gke-managed-checkpointing
NAME                                                          READY   STATUS     RESTARTS   AGE
multitier-driver-14694e4d-774f-4104-8bba-f0bd82fd7557-5vxr9   0/5     Init:0/1   0          6m32s

如需解决此问题,请检查 CSI 驱动程序 Pod 事件,如以下示例所示:

kubectl describe pod multitier-driver-14694e4d-774f-4104-8bba-f0bd82fd7557-5vxr9 -n gke-managed-checkpointing

Events:
  Type     Reason       Age                 From               Message
  ----     ------       ----                ----               -------
  Normal   Scheduled    17m                 default-scheduler  Successfully assigned gke-managed-checkpointing/multitier-driver-14694e4d-774f-4104-8bba-f0bd82fd7557-5vxr9 to gke-my-cluster-default-pool-353c773f-6d8q
  Warning  FailedMount  82s (x16 over 17m)  kubelet            MountVolume.SetUp failed for volume "gcs" : rpc error: code = PermissionDenied desc = failed to get GCS bucket "checkpointing-test-bucket": googleapi: Error 403: Caller does not have storage.objects.list access to the Google Cloud Storage bucket. Permission 'storage.objects.list' denied on resource (or it may not exist)., forbidden

如果问题是因 Cloud Storage 存储桶 PermissionDenied 错误而导致的(如示例所示),您可以通过正确设置权限来解决此问题。

后续步骤