Efficient and Adaptable Overlapping for Computation and Communication via Signaling and Reordering
摘要
生成模型在多 GPU 计算中取得巨大成功,但 GPU 间通信成为瓶颈,尤其在消费级 GPU 上。利用并发硬件执行重叠计算和通信延迟是缓解通信开销的有效技术。作者指出高效且可适应的重叠设计应满足:(1) tile-wise 重叠以最大化重叠机会;(2) 无干扰计算以保持原始计算性能;(3) 通信不可知以减少针对不同通信原语的开发负担。现有设计无法同时优化这三个特性。本文提出 FlashOverlap,利用新型信号机制:当部分输出完成时,计算 kernel 发送信号触发该部分的通信,同时继续剩余部分的计算。在此之上包含两个关键组件:(1) 信号时机确定以提升重叠效率;(2) 通信前重排创建连续地址,允许直接调用 NCCL API。实验表明 FlashOverlap 通过重叠实现最高 1.65x 加速,在大多数情况下优于现有工作。
背景与动机
多 GPU 并行模式与通信开销
生成模型参数量从数十亿增长到万亿级别 (如 to 2T parameters)。单 GPU 显存无法容纳,需要张量并行 (TP)、流水线并行 (PP)、专家并行 (EP) 和数据并行 (DP)。这些多 GPU 范式不可避免地引入 GPU 间通信开销,主要来自集合通信原语:AllReduce、ReduceScatter 和 All-to-All。消费级 GPU 使用 PCIe 互连 (通常 16-64 GB/s 双向带宽),通信开销更加严重。
GEMM + 通信模式分析
论文 profiling 了多种 LLM 推理和训练场景:
- • GEMM + AllReduce (TP 推理): Llama3-70B (TP=8) 推理中,prefill 和 decode 阶段 GEMM+AR 占总时间约 8.8-8.9%
- • GEMM + ReduceScatter (FSDP 训练): Llama2-7B (TP=4, PP=2) 训练 backward 中 GEMM+RS 占约 30%
- • GEMM + All-to-All (MoE 训练): Mixtral-8x7B (EP=8) 训练 forward 中 GEMM+A2A 占超过 40%
- • Step-Video-T2V 推理 (TP=4): GEMM+AR 占约 9.9%
高效可适应重叠的三个关键特性
(1) Tile-wise 重叠: Tile 是 GEMM 输出中逻辑上最小的并行数据单元,tile-wise 重叠最大化重叠机会。
(2) 无干扰计算: 为保持原始计算性能,应避免对 GEMM 的干扰,包括分割 (segmentation)、分块 (tiling) 或逻辑变更。
(3) 通信不可知: 设计应对通信原语不可知,无需为不同通信原语重复开发。
现有方法对比
Figure 3: 三种重叠方法对比 — (a) Decomposition 法将 GEMM 分割为多个子张量,(b) Fusion 法将通信融合到 GEMM kernel 内,(c) FlashOverlap 信号法保持 GEMM 完整,通过信号触发通信
| 方法 | Tile-wise | 无干扰计算 | 通信不可知 | 代表工作 |
|---|---|---|---|---|
| 分解法 (Decomposition) | 否 | 否 | 是 | CoCoNet, Domino, Async-TP, MegaScale, Centauri |
| 融合法 (Fusion) | 是 | 否 | 否 | FLUX, DistServe, AMD 融合内核 |
| 信号法 (FlashOverlap) | 是 | 是 | 是 | 本文 |
分解法问题: 将 GEMM 输出张量分解为多个子张量以实现异步重叠。但无法实现 tile-wise 重叠 (需要将 GEMM 分割为较小的 GEMM),且干扰原始计算逻辑,可能导致性能退化需要进一步调优。
融合法问题: 将通信原语融合到 GEMM kernel 内部。虽高效但需要重复 GEMM 调优和通信原语实现,每次 GEMM 尺寸变化或通信原语变化都需要重新开发。
FlashOverlap 核心设计
系统总览
GEMM 计算保持在单个 GPU kernel 内,当每组 (G1, G2, G3) tiles 完成时,首先将组内 tiles 重排到连续地址,然后发送信号触发对应的 GPU 间通信。通信完成后,tiles 被重排回正确顺序。
Figure 5: FlashOverlap 信号与重排机制总览 — 左侧展示不同 group 划分下的信号时机,右侧展示 pre-comm 和 post-comm 重排模式
GEMM computation 保持单个 GPU kernel 执行,不做任何分割或逻辑变更。通过信号机制在部分数据完成计算时通知通信层,计算继续处理剩余数据。
通信作为独立 GPU kernel 在不同 stream 上异步执行。通过重排操作确保通信数据地址连续,可直接调用 NCCL API,无需自定义通信实现。
一、信号机制 (Signaling)
动机: 数据依赖性
信号机制用于追踪已完成 GEMM 计算、可通信的数据,同时不对 GEMM 计算产生干扰。理想情况下,当某部分数据完成 GEMM 计算时,一个开销可忽略的信号用于启动对应的通信,而 GEMM kernel 继续计算。这样,信号链式连接数据依赖,无计算干扰和手工 kernel 融合。
关键洞察: GEMM 中的 Wave 模式
论文观察到 GEMM 执行中存在固有的 wave 模式:多个 tiles 几乎同时完成,如同在一个 wave 中。在 RTX 4090 上对 GEMM (M=2048, N=K=8192) 的 profiling 显示,tile 完成时间可明确分为 4 个 distinct waves,与 tile 数量 (512) 除以 SM 数量 (128) 的结果一致。
因此,使用 wave 而非单个 tile 作为重叠单元,本质上相同的重叠机会下实现更好的带宽利用率。
Figure 2: GEMM 执行中的 Wave 模式 — Tile 按完成时间分组为 wave,4 个 wave 对应 512 tiles / 128 SMs
Block Swizzling 的影响
GEMM 中 tile 执行顺序受 block swizzling 技术影响。Block swizzling 以交错方式将 tiles 调度到 SMs 以提升内存访问效率。这导致 wave 中 tile 的内存地址不连续,阻止已完成 tiles 的及时通信。因此需要引入数据重排技术来解决不匹配问题。
Wave Group 设计
静态 wave-wise 信号并非最优。存在更小但即时通信 vs 更大但延迟通信的权衡。因此信号时机设计为可调,在 waves 之上定义 wave group。一个 group G 包含 |G| >= 1 个 waves,对应通信在每个 group 完成 GEMM 计算后开始。每个 group 的大小可调。
实现方法: 组级 Tile 计数
引入计数表 (counting table) 追踪 tiles 完成情况,已完成的 tiles 按组分别记录。计数表大小为 P,表示 tiles 被分为 P 个不同组 (G1, G2, ..., GP)。当 G_j 中的一个 tile 完成时,计数表中第 j 个数字原子性地加 1。当第 j 个数字达到 |G_j| 时,G_j 的通信开始。使用 tile index 来识别 tile 属于哪个组。
二、重排机制 (Reordering)
动机: 通信需要连续地址
通过调用 NCCL 库的单次 GPU 间通信要求发送和接收缓冲区都有连续地址。要实现灵活的 tile 通信,一起通信的 tile 内 (intra-tile) 和 tile 间 (inter-tile) 数据都需要重排到连续地址。
挑战: 不规则的 Tile 执行顺序
除了 tile 内数据天然不连续外,GEMM 中 tile 间执行顺序也不规则。根本原因是 block swizzling 的应用——为优化 GEMM 性能,将 tiles 以交错方式调度到 GPU blocks。这导致同一 wave 中完成的 tiles 在内存地址上不连续。
解决方案: 映射表重排
使用映射表 (mapping table) 将不连续的 tile 索引映射到连续的重排序索引。预通信重排 (pre-communication reordering) 确保待发数据的连续地址,允许直接调用 NCCL API。通信后重排 (post-communication reordering) 纠正数据顺序,恢复原始布局。
不同通信原语的重排模式
Figure 6: 三种通信原语下的重排模式对比 — (a) AllReduce 以 tile 为单元重排,(b) ReduceScatter 以 subtile 为单元沿行维度切分,(c) All-to-All 以 subtoken 为单元按目标 GPU 路由
(1) AllReduce
以 tile 为重排单元。使用映射表将不连续的 tile 索引 (如 tile 0 和 tile 3) 映射到连续的重排序索引 (0 和 1)。通信后通过逆映射表恢复原始顺序。AllReduce 保持每行在所有 GPU 上完整,重排仅改变 tile 的内存布局。
(2) ReduceScatter + AllGather
以 subtile 为重排单元。每个 tile 沿行维度等分为与 GPU 数量相同的 subtiles。无论 tiles 如何分配到 GPU,tile 内的第 k 个 subtile 最终总是驻留在第 k 个 GPU 上,所有第 k 个 subtiles 形成完整的行。映射表相应调整。后续 AllGather 聚合所有行,行顺序可通过本地行交换 (row exchange) 纠正——这只是简单的块循环置换,无需映射表,可与 element-wise 操作无缝融合。
(3) All-to-All
在 MoE 模型中,数据动态路由到特定 GPU 上的专家。以 subtoken 为重排单元。每个 tile 的输出按目标 GPU 分割为 subtokens。预通信重排确保发往同一 GPU 的 subtokens 地址连续。通信后,每个 GPU 上的 subtokens 通过映射表恢复正确顺序。映射表需要考虑动态路由模式。
三、Wave Group 划分设计空间
设计空间大小
每次 wave 后,可选择通信 ("1") 或不通信 ("0")。最后一个 wave 例外——所有累积的 tiles 必须通信。假设有 T 个 waves,设计空间大小为 2^(T-1)。
Figure 7: Wave Group 划分示例 — 5 个 waves 的 4 种不同划分方式,展示不同 group 组合如何影响信号时机和重叠效率
例如 T=5 时,有 2^4 = 16 种选择。一种选择是在 W1, W3, W5 后通信,得到 wave group 划分 (1, 2, 2),即 |G1|=1, |G2|=2, |G3|=2。另一种选择在 W2, W5 后通信,得到划分 (2, 3)。
为什么需要调优?
论文实验发现:在最细粒度划分 (每个 group 包含 1 个 wave) 作为基线的情况下,在 4 张 RTX 4090 上对 50+ 种 GEMM 尺寸使用 AllReduce 原语的测试中,只有 4% 的情况基线划分最优。使用基线划分平均导致 17.34% 性能退化。根本原因是分段通信导致带宽利用率不足,以及频繁 API 调用的开销,当通信延迟占主导时成为性能瓶颈。
调优挑战
原始设计空间大小 2^(T-1),每个候选划分需要在线执行来选择最优,导致不可忽略的调优开销。典型 GEMM (M=4096, N=8192, K=7168) 在 RTX 4090 上产生 T=8 个 waves,等于 128 个候选。在线执行约 5ms,profiling 通常包括 10 次 warm-up 和 100 次计时测试,超过 1 分钟 (模型前向延迟的 100+ 倍),对端到端性能不可接受。
四、实时调优: 预测搜索方法
两阶段调优架构
离线阶段 (Offline): 处理部署设置——获取 GEMM 配置 (tiling size, swizzling pattern, duration)、通信带宽曲线 (bandwidth curve)、SM 资源竞争分析。这些只在部署架构、硬件和网络拓扑固定时执行一次。
在线阶段 (Online): 针对不同 GEMM 尺寸重复调优——基于延迟预测在设计空间中搜索最优划分,无需在线 profiling。
离线阶段详情
GEMM 配置获取: 给定 M x N x K,利用现有高度优化的线性代数实现 (cuBLAS, CUTLASS) 获取 tiling size、swizzling pattern、对应 duration 等配置。
通信带宽曲线采样: 对给定 GPU 执行通信原语,带宽随数据大小连续变化。用多个密集点采样带宽曲线。
SM 资源竞争分析: 确定通信 kernel 占用的 SM 数量,从 GEMM 可用的 SM 中扣除。
在线阶段: 预测搜索算法
基于离线阶段获取的 GEMM 配置和带宽曲线,在线阶段使用延迟预测器替代在线 profiling 来搜索最优 wave group 划分。
对于每个候选划分 G = (G1, G2, ..., GP):
- • 对每个组 G_i,通过插值带宽曲线获取通信延迟 t_m = interp_latency(bdw_curve, data_size)
- • 计算该组的计算延迟 t_p = gemm_config.duration / T * |G_i|
- • 累积延迟: t_acc_m = max(t_acc_p, t_acc_m) + t_m, t_acc_p += t_p
- • 最后加上最后一组的通信延迟
- • 选择总累积延迟最小的划分作为最优解
五、CUTLASS 扩展实现
扩展主循环 (Main Loop)
FlashOverlap 扩展 CUTLASS 矩阵例程,保留核心计算循环,在其中嵌入信号和重排逻辑。在每个 wave 的 tile 完成时,原子性地更新计数表。当计数表指示某个 group 完成时,触发通信前的重排操作。
Epilogue 集成
GEMM 执行包含主循环 (main loop) 和 epilogue。主循环执行核心乘加操作,占 GEMM 大部分时间。Epilogue 指矩阵乘法后执行的 element-wise 操作 (如 ReLU, SiLU, bias addition)。这些操作通常与前面的矩阵乘法融合到单个 GPU kernel 中,消除冗余内存访问和 kernel 启动开销。FlashOverlap 将通信后重排与 element-wise 操作融合,避免额外的 kernel 启动。
专用监控内核
通过原子计数器 (atomic counters) 跟踪进度。独立执行流分别管理计算和传输工作流。信号机制使用 CUDA 事件或共享内存原子变量实现低开销通知。
实验评估结果
典型 GEMM 性能对比
性能对比图 — FlashOverlap 通过最优 wave group 划分,在多种 GEMM 尺寸下均优于基线方法,平均加速比显著
Wave Pattern 验证
在 RTX 4090 上对 GEMM (M=2048, N=K=8192) 的 profiling 显示,tile 完成时间可明确分为 4 个 distinct waves,与 tile 数量 (512) 除以 SM 数量 (128) 的结果一致。完成顺序与内存地址 (tile index) 不对齐,证实了 block swizzling 的影响。
调优必要性验证
在 4 张 RTX 4090 上对 50+ 种 GEMM 尺寸使用 AllReduce 的测试中,只有 4% 的情况最细粒度划分 (每个 group 1 个 wave) 是最优的。使用基线划分平均导致 17.34% 性能退化,证明 wave group 调优的必要性。
预测搜索精度
预测搜索方法通过离线采样的带宽曲线和 GEMM 配置,在线预测各候选划分的延迟,避免了昂贵的在线 profiling (1 分钟+ vs 毫秒级预测),同时保持高精度找到最优或接近最优的划分。
测试硬件与模型
实验覆盖多种硬件和模型场景:
- • NVIDIA A800: Llama3-70B 推理 (TP=8), Mixtral-8x7B 训练 (EP=8), Step-Video-T2V 推理 (TP=4), Llama2-7B 训练 (TP=4, PP=2)
- • NVIDIA RTX 4090: 50+ GEMM 尺寸, AllReduce, 4 GPU 测试
- • 通信原语覆盖: AllReduce, ReduceScatter, All-to-All
核心洞察
1. 信号机制解耦计算与通信
FlashOverlap 的核心创新是使用信号 (signal) 而非分解或融合来链式连接计算-通信数据依赖。GEMM 保持单个 kernel 执行,仅在部分数据完成时发送轻量信号触发通信,实现真正的无干扰计算。
2. Wave 模式是天然的重叠单元
GEMM 执行中 tile 完成时间的 wave 模式 (与 SM 数量对应的批次) 提供了天然的重叠粒度。使用 wave 而非单个 tile 作为信号单元,在几乎不损失重叠机会的情况下提升带宽利用率。
3. Wave Group 调优至关重要
最细粒度划分仅在 4% 情况下最优,平均退化 17.34%。通信带宽利用率和 API 调用开销的权衡使得最优划分随 GEMM 尺寸、通信原语和硬件变化。预测搜索方法使实时调优成为可能。
4. 重排实现通信不可知
预通信重排创建连续地址使直接调用 NCCL API 成为可能,无需为每种通信原语自定义实现。通信后重排恢复原始数据顺序。这对 AllReduce、ReduceScatter 和 All-to-All 均有适配方案。