介绍 JobSet

作者: Daniel Vega-Myhre (Google), Abdullah Gharaibeh (Google), Kevin Hannon (Red Hat)

在本文中,我们介绍 JobSet,这是一个用于表示分布式作业的开源 API。JobSet 的目标是为 Kubernetes 上的分布式 ML 训练和 HPC 工作负载提供一个统一的 API。

为什么需要 JobSet?

Kubernetes 社区近期对 Kubernetes 上批处理生态系统的增强吸引了 ML 工程师,他们发现 Kubernetes 非常适合运行分布式训练工作负载的需求。

无法放入单个主机 GPU 或 TPU 芯片内存的大型 ML 模型(特别是 LLM)通常会分布到数万个加速器芯片上,这些芯片又可能跨越数千台主机。

因此,模型训练代码通常会被容器化并在所有这些主机上同时执行,执行分布式计算,这些计算通常会将模型参数和/或训练数据集分片到目标加速器芯片上,使用 all-gather 和 all-reduce 等集合通信原语执行分布式计算并在主机之间同步梯度。

这些工作负载特性使得 Kubernetes 非常适合此类工作负载,因为它在计算资源集群中高效地调度和管理容器化应用的生命周期方面表现出色。

它也非常具有可扩展性,允许开发者定义自己的 Kubernetes API、对象和控制器来管理这些对象的行为和生命周期,从而让工程师能够开发定制的分布式训练编排解决方案来满足他们的需求。

然而,随着分布式 ML 训练技术的不断发展,现有的 Kubernetes 原语单独已不足以对其进行充分建模。

此外,Kubernetes 分布式训练编排 API 的生态已经变得碎片化,并且这个碎片化生态中的现有解决方案都有某些局限性,使其对于分布式 ML 训练来说并非最优选择。

例如,KubeFlow 训练 operator 为不同的框架定义了定制 API(例如 PyTorchJob、TFJob、MPIJob 等);然而,这些作业类型实际上都是专门针对目标框架的解决方案,每种作业类型都有不同的语义和行为。

另一方面,Job API 修复了运行批处理工作负载的许多不足,包括 Indexed completion mode、更高的可伸缩性、Pod 失败策略和 Pod 退避策略等一些最新的增强功能。然而,使用上游 Job API 运行 ML 训练和 HPC 工作负载需要额外的编排来弥补以下不足之处:

多模板 Pod:大多数 HPC 或 ML 训练作业包含不止一种类型的 Pod。不同的 Pod 是同一工作负载的一部分,但它们需要运行不同的容器、请求不同的资源或具有不同的失败策略。一个常见的例子是驱动器-工作器模式。

作业组:大规模训练工作负载跨越多种网络拓扑,例如运行在多个机架上。此类工作负载对网络延迟敏感,旨在本地化通信并最大程度地减少跨高延迟网络链路的流量。为了促进这一点,需要将工作负载拆分为 Pod 组,每个组分配给一个网络拓扑。

Pod 间通信:创建和管理建立作业 Pod 之间通信所需的资源(例如 无头服务)。

启动顺序:某些作业需要特定的 Pod 启动顺序;有时期望驱动器首先启动(如 Ray 或 Spark),在其他情况下则期望工作器在启动驱动器之前准备就绪(如 MPI)。

JobSet 旨在利用 Job API 作为构建模块来弥补这些不足之处,从而为大规模分布式 HPC 和 ML 用例构建更丰富的 API。

JobSet 的工作原理

JobSet 将分布式批处理工作负载建模为一组 Kubernetes Job。这允许用户轻松地为不同类型的 Pod 组(例如 leader、workers、参数服务器等)指定不同的 Pod 模板。

它使用 ReplicatedJob 抽象来管理子 Job,其中 ReplicatedJob 本质上是一个 Job 模板,指定了所需的 Job 副本数量。这提供了一种声明式的方式,可以轻松创建相同的子 Job 以在不同的加速器岛上运行,而无需借助脚本或 Helm charts 生成许多名称不同的相同作业版本。

JobSet Architecture

JobSet 解决上述问题的其他一些关键特性包括:

Replicated Jobs:在现代数据中心,GPU 和 TPU 等硬件加速器分配在通过专用高带宽网络链接连接的同构加速器岛中。例如,用户可以配置包含一组位于同一机架上的主机节点,每台主机都配有 H100 GPU,其中每台主机内的 GPU 芯片通过 NVLink 连接,并由 NVLink Switch 连接多个 NVLink。TPU Pods 是另一个例子:TPU ViperLitePods 由 64 台主机组成,每台主机连接有 4 个 TPU v5e 芯片,所有芯片都通过 ICI 网格连接。当跨多个此类岛屿运行分布式训练作业时,我们通常希望将工作负载划分为一组较小的、完全相同的作业,每个岛屿一个作业,其中每个 Pod 主要与同一岛屿内的 Pod 进行通信以执行分布式计算的片段,并将通过 DCN(数据中心网络,带宽低于 ICI)进行的梯度同步保持在最低限度。

自动无头服务创建、配置和生命周期管理:通过 Pod 主机名进行 Pod 间通信默认启用,支持此功能的无头服务进行自动配置和生命周期管理。

可配置的成功策略:JobSet 具有可配置的成功策略,这些策略针对特定的 ReplicatedJobs,并提供操作符来针对其子作业的“Any”或“All”。例如,您可以配置 JobSet,使其仅当属于“worker”ReplicatedJob 的所有 Pod 都完成后才标记为完成。

可配置的失败策略:JobSet 具有可配置的失败策略,允许用户指定 JobSet 在发生故障时应重启的最大次数。如果任何作业被标记为失败,整个 JobSet 将被重新创建,从而允许工作负载从上一个检查点恢复。如果未指定失败策略,则任何作业失败时,JobSet 将直接失败。

每个拓扑域的独占放置:JobSet 允许用户表达子作业与拓扑域(通常是机架等加速器岛)之间存在 1:1 的独占分配关系。例如,如果 JobSet 创建了两个子作业,则此特性将强制规定每个子作业的 Pod 将共置在同一岛屿上,并且每个岛屿只允许调度一个子作业。这对于我们希望使用分布式数据并行 (DDP) 训练策略,利用多个计算资源岛(GPU 机架或 TPU 切片)训练模型,并在每个加速器岛中运行 1 个模型副本的场景非常有用,这样可以确保模型副本内部的前向和后向传递本身发生在岛屿内连接加速器芯片的高带宽互连网络上,而模型副本之间的梯度同步仅通过低带宽数据中心网络跨加速器岛屿发生。

与 Kueue 集成:用户可以通过 Kueue 提交 JobSets,从而实现集群超额订阅、在容量可用时排队运行工作负载、防止部分调度和死锁、启用多租户等功能。

示例用例

使用 Jax 在多个 TPU 切片上进行分布式 ML 训练

以下示例是一个 JobSet 规范,用于在 4 个 TPU v5e 切片上运行 TPU 多切片工作负载。要了解更多关于 TPU 的概念和术语,请参阅这些文档

本例使用 Jax,这是一个 ML 框架,通过 OpenXLA 原生支持针对 TPU 芯片的即时 (JIT) 编译。不过,您也可以使用 PyTorch/XLA 在 TPU 上进行 ML 训练。

本例利用了 JobSet 的几项特性(显式的和隐式的),以便开箱即用地支持 TPU 多切片训练独特的调度需求,用户只需很少的配置。

# Run a simple Jax workload on 
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice
  annotations:
    # Give each child Job exclusive usage of a TPU slice 
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 3
  replicatedJobs:
  - name: workers
    replicas: 4 # Set to number of TPU slices
    template:
      spec:
        parallelism: 2 # Set to number of VMs per TPU slice
        completions: 2 # Set to number of VMs per TPU slice
        backoffLimit: 0
        template:
          spec:
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
              cloud.google.com/gke-tpu-topology: 2x4
            containers:
            - name: jax-tpu
              image: python:3.8
              ports:
              - containerPort: 8471
              - containerPort: 8080
              securityContext:
                privileged: true
              command:
              - bash
              - -c
              - |
                pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                python -c 'import jax; print("Global device count:", jax.device_count())'
                sleep 60                
              resources:
                limits:
                  google.com/tpu: 4

未来工作和参与方式

我们计划在今年开发 JobSet 路线图上的许多特性,这些特性可以在 JobSet 路线图中找到。

欢迎随时提供任何形式的反馈。我们也欢迎更多贡献者加入,无论是修复或报告 bug,还是帮助添加新特性或编写文档。

您可以通过我们的 仓库邮件列表或在 Slack 上联系我们。

最后但同样重要的是,感谢所有使这个项目成为可能的贡献者们