FlashOverlap 论文详解

信号驱动计算通信重叠 | EUROSYS '26

最高加速比
1.65x
arXiv:2504.19519

Efficient and Adaptable Overlapping for Computation and Communication via Signaling and Reordering

作者
Ke Hong (Tsinghua/Infinigence-AI), Xiuhong Li*, Minxu Liu, Qiuli Mao, Tianqi Wu, Zixiao Huang, Lufang Chen (Infinigence-AI), Zhong Wang, Yichong Zhang, Zhenhua Zhu (Tsinghua), Guohao Dai* (SJTU/Infinigence-AI), Yu Wang* (Tsinghua)
发表信息
EUROSYS '26, April 27-30, 2026, Edinburgh, Scotland UK
DOI: 10.1145/3767295.3769370

摘要

生成模型在多 GPU 计算中取得巨大成功,但 GPU 间通信成为瓶颈,尤其在消费级 GPU 上。利用并发硬件执行重叠计算和通信延迟是缓解通信开销的有效技术。作者指出高效且可适应的重叠设计应满足:(1) tile-wise 重叠以最大化重叠机会;(2) 无干扰计算以保持原始计算性能;(3) 通信不可知以减少针对不同通信原语的开发负担。现有设计无法同时优化这三个特性。本文提出 FlashOverlap,利用新型信号机制:当部分输出完成时,计算 kernel 发送信号触发该部分的通信,同时继续剩余部分的计算。在此之上包含两个关键组件:(1) 信号时机确定以提升重叠效率;(2) 通信前重排创建连续地址,允许直接调用 NCCL API。实验表明 FlashOverlap 通过重叠实现最高 1.65x 加速,在大多数情况下优于现有工作。

背景与动机

多 GPU 并行模式与通信开销

生成模型参数量从数十亿增长到万亿级别 (如 to 2T parameters)。单 GPU 显存无法容纳,需要张量并行 (TP)、流水线并行 (PP)、专家并行 (EP) 和数据并行 (DP)。这些多 GPU 范式不可避免地引入 GPU 间通信开销,主要来自集合通信原语:AllReduce、ReduceScatter 和 All-to-All。消费级 GPU 使用 PCIe 互连 (通常 16-64 GB/s 双向带宽),通信开销更加严重。

GEMM + 通信模式分析

论文 profiling 了多种 LLM 推理和训练场景:

  • GEMM + AllReduce (TP 推理): Llama3-70B (TP=8) 推理中,prefill 和 decode 阶段 GEMM+AR 占总时间约 8.8-8.9%
  • GEMM + ReduceScatter (FSDP 训练): Llama2-7B (TP=4, PP=2) 训练 backward 中 GEMM+RS 占约 30%
  • GEMM + All-to-All (MoE 训练): Mixtral-8x7B (EP=8) 训练 forward 中 GEMM+A2A 占超过 40%
  • Step-Video-T2V 推理 (TP=4): GEMM+AR 占约 9.9%

高效可适应重叠的三个关键特性

(1) Tile-wise 重叠: Tile 是 GEMM 输出中逻辑上最小的并行数据单元,tile-wise 重叠最大化重叠机会。

(2) 无干扰计算: 为保持原始计算性能,应避免对 GEMM 的干扰,包括分割 (segmentation)、分块 (tiling) 或逻辑变更。

(3) 通信不可知: 设计应对通信原语不可知,无需为不同通信原语重复开发。

现有方法对比

Figure 3: Comparison of overlap methods - Decomposition, Fusion, and Signaling

Figure 3: 三种重叠方法对比 — (a) Decomposition 法将 GEMM 分割为多个子张量,(b) Fusion 法将通信融合到 GEMM kernel 内,(c) FlashOverlap 信号法保持 GEMM 完整,通过信号触发通信

方法Tile-wise无干扰计算通信不可知代表工作
分解法 (Decomposition)CoCoNet, Domino, Async-TP, MegaScale, Centauri
融合法 (Fusion)FLUX, DistServe, AMD 融合内核
信号法 (FlashOverlap)本文

分解法问题: 将 GEMM 输出张量分解为多个子张量以实现异步重叠。但无法实现 tile-wise 重叠 (需要将 GEMM 分割为较小的 GEMM),且干扰原始计算逻辑,可能导致性能退化需要进一步调优。

融合法问题: 将通信原语融合到 GEMM kernel 内部。虽高效但需要重复 GEMM 调优和通信原语实现,每次 GEMM 尺寸变化或通信原语变化都需要重新开发。

FlashOverlap 核心设计

系统总览

GEMM 计算保持在单个 GPU kernel 内,当每组 (G1, G2, G3) tiles 完成时,首先将组内 tiles 重排到连续地址,然后发送信号触发对应的 GPU 间通信。通信完成后,tiles 被重排回正确顺序。

Figure 5: Signaling and reordering mechanism overview

Figure 5: FlashOverlap 信号与重排机制总览 — 左侧展示不同 group 划分下的信号时机,右侧展示 pre-comm 和 post-comm 重排模式

计算侧

GEMM computation 保持单个 GPU kernel 执行,不做任何分割或逻辑变更。通过信号机制在部分数据完成计算时通知通信层,计算继续处理剩余数据。

通信侧

通信作为独立 GPU kernel 在不同 stream 上异步执行。通过重排操作确保通信数据地址连续,可直接调用 NCCL API,无需自定义通信实现。

一、信号机制 (Signaling)

动机: 数据依赖性

信号机制用于追踪已完成 GEMM 计算、可通信的数据,同时不对 GEMM 计算产生干扰。理想情况下,当某部分数据完成 GEMM 计算时,一个开销可忽略的信号用于启动对应的通信,而 GEMM kernel 继续计算。这样,信号链式连接数据依赖,无计算干扰和手工 kernel 融合。

关键洞察: GEMM 中的 Wave 模式

论文观察到 GEMM 执行中存在固有的 wave 模式:多个 tiles 几乎同时完成,如同在一个 wave 中。在 RTX 4090 上对 GEMM (M=2048, N=K=8192) 的 profiling 显示,tile 完成时间可明确分为 4 个 distinct waves,与 tile 数量 (512) 除以 SM 数量 (128) 的结果一致。

因此,使用 wave 而非单个 tile 作为重叠单元,本质上相同的重叠机会下实现更好的带宽利用率。

Figure 2: Wave pattern in GEMM execution - tiles grouped into waves based on completion time

Figure 2: GEMM 执行中的 Wave 模式 — Tile 按完成时间分组为 wave,4 个 wave 对应 512 tiles / 128 SMs

Block Swizzling 的影响

GEMM 中 tile 执行顺序受 block swizzling 技术影响。Block swizzling 以交错方式将 tiles 调度到 SMs 以提升内存访问效率。这导致 wave 中 tile 的内存地址不连续,阻止已完成 tiles 的及时通信。因此需要引入数据重排技术来解决不匹配问题。

Wave Group 设计

静态 wave-wise 信号并非最优。存在更小但即时通信 vs 更大但延迟通信的权衡。因此信号时机设计为可调,在 waves 之上定义 wave group。一个 group G 包含 |G| >= 1 个 waves,对应通信在每个 group 完成 GEMM 计算后开始。每个 group 的大小可调。

实现方法: 组级 Tile 计数

引入计数表 (counting table) 追踪 tiles 完成情况,已完成的 tiles 按组分别记录。计数表大小为 P,表示 tiles 被分为 P 个不同组 (G1, G2, ..., GP)。当 G_j 中的一个 tile 完成时,计数表中第 j 个数字原子性地加 1。当第 j 个数字达到 |G_j| 时,G_j 的通信开始。使用 tile index 来识别 tile 属于哪个组。

// Wave Group 计数表示例 (3 组, 共 8 tiles, 4 waves)
// G1 = {W1} = {tile 0, 1}, |G1| = 2
// G2 = {W2, W3} = {tile 2,3,4,5}, |G2| = 4
// G3 = {W4} = {tile 6, 7}, |G3| = 2
Counting Table 初始化: [0, 0, 0]
W1 完成 (tiles 0,1) -> 计数表 [2, 0, 0] -> G1 计数满 -> 触发 G1 通信
W2 完成 (tiles 2,3) -> 计数表 [2, 2, 0] -> G2 未满,等待
W3 完成 (tiles 4,5) -> 计数表 [2, 4, 0] -> G2 计数满 -> 触发 G2 通信
W4 完成 (tiles 6,7) -> 计数表 [2, 4, 2] -> G3 计数满 -> 触发 G3 通信

二、重排机制 (Reordering)

动机: 通信需要连续地址

通过调用 NCCL 库的单次 GPU 间通信要求发送和接收缓冲区都有连续地址。要实现灵活的 tile 通信,一起通信的 tile 内 (intra-tile) 和 tile 间 (inter-tile) 数据都需要重排到连续地址。

挑战: 不规则的 Tile 执行顺序

除了 tile 内数据天然不连续外,GEMM 中 tile 间执行顺序也不规则。根本原因是 block swizzling 的应用——为优化 GEMM 性能,将 tiles 以交错方式调度到 GPU blocks。这导致同一 wave 中完成的 tiles 在内存地址上不连续。

解决方案: 映射表重排

使用映射表 (mapping table) 将不连续的 tile 索引映射到连续的重排序索引。预通信重排 (pre-communication reordering) 确保待发数据的连续地址,允许直接调用 NCCL API。通信后重排 (post-communication reordering) 纠正数据顺序,恢复原始布局。

不同通信原语的重排模式
Figure 6: Reordering patterns for AllReduce, ReduceScatter, and All-to-All

Figure 6: 三种通信原语下的重排模式对比 — (a) AllReduce 以 tile 为单元重排,(b) ReduceScatter 以 subtile 为单元沿行维度切分,(c) All-to-All 以 subtoken 为单元按目标 GPU 路由

(1) AllReduce

以 tile 为重排单元。使用映射表将不连续的 tile 索引 (如 tile 0 和 tile 3) 映射到连续的重排序索引 (0 和 1)。通信后通过逆映射表恢复原始顺序。AllReduce 保持每行在所有 GPU 上完整,重排仅改变 tile 的内存布局。

(2) ReduceScatter + AllGather

以 subtile 为重排单元。每个 tile 沿行维度等分为与 GPU 数量相同的 subtiles。无论 tiles 如何分配到 GPU,tile 内的第 k 个 subtile 最终总是驻留在第 k 个 GPU 上,所有第 k 个 subtiles 形成完整的行。映射表相应调整。后续 AllGather 聚合所有行,行顺序可通过本地行交换 (row exchange) 纠正——这只是简单的块循环置换,无需映射表,可与 element-wise 操作无缝融合。

(3) All-to-All

在 MoE 模型中,数据动态路由到特定 GPU 上的专家。以 subtoken 为重排单元。每个 tile 的输出按目标 GPU 分割为 subtokens。预通信重排确保发往同一 GPU 的 subtokens 地址连续。通信后,每个 GPU 上的 subtokens 通过映射表恢复正确顺序。映射表需要考虑动态路由模式。

// 预通信重排 (Pre-communication Reordering)
// 输入: GEMM 输出 tiles [tile_0, tile_3, tile_1, tile_2, ...] (地址不连续)
// 映射表: {tile_0 -> 0, tile_3 -> 1, tile_1 -> 2, tile_2 -> 3, ...}
// 输出: 连续缓冲区 [tile_0_data, tile_3_data, tile_1_data, tile_2_data, ...]
// -> 可直接调用 NCCL API: ncclAllReduce(contiguous_buf, ...)
// 通信后重排 (Post-communication Reordering)
// 输入: 通信结果 [result_0, result_1, result_2, result_3, ...]
// 逆映射表: {0 -> tile_0, 1 -> tile_3, 2 -> tile_1, 3 -> tile_2, ...}
// 输出: 恢复原始布局 [result_0 放回 tile_0 位置, result_1 放回 tile_3 位置, ...]

三、Wave Group 划分设计空间

设计空间大小

每次 wave 后,可选择通信 ("1") 或不通信 ("0")。最后一个 wave 例外——所有累积的 tiles 必须通信。假设有 T 个 waves,设计空间大小为 2^(T-1)。

Figure 7: Example of four wave group partitions

Figure 7: Wave Group 划分示例 — 5 个 waves 的 4 种不同划分方式,展示不同 group 组合如何影响信号时机和重叠效率

例如 T=5 时,有 2^4 = 16 种选择。一种选择是在 W1, W3, W5 后通信,得到 wave group 划分 (1, 2, 2),即 |G1|=1, |G2|=2, |G3|=2。另一种选择在 W2, W5 后通信,得到划分 (2, 3)。

为什么需要调优?

论文实验发现:在最细粒度划分 (每个 group 包含 1 个 wave) 作为基线的情况下,在 4 张 RTX 4090 上对 50+ 种 GEMM 尺寸使用 AllReduce 原语的测试中,只有 4% 的情况基线划分最优。使用基线划分平均导致 17.34% 性能退化。根本原因是分段通信导致带宽利用率不足,以及频繁 API 调用的开销,当通信延迟占主导时成为性能瓶颈。

调优挑战

原始设计空间大小 2^(T-1),每个候选划分需要在线执行来选择最优,导致不可忽略的调优开销。典型 GEMM (M=4096, N=8192, K=7168) 在 RTX 4090 上产生 T=8 个 waves,等于 128 个候选。在线执行约 5ms,profiling 通常包括 10 次 warm-up 和 100 次计时测试,超过 1 分钟 (模型前向延迟的 100+ 倍),对端到端性能不可接受。

四、实时调优: 预测搜索方法

两阶段调优架构

离线阶段 (Offline): 处理部署设置——获取 GEMM 配置 (tiling size, swizzling pattern, duration)、通信带宽曲线 (bandwidth curve)、SM 资源竞争分析。这些只在部署架构、硬件和网络拓扑固定时执行一次。

在线阶段 (Online): 针对不同 GEMM 尺寸重复调优——基于延迟预测在设计空间中搜索最优划分,无需在线 profiling。

离线阶段详情

1

GEMM 配置获取: 给定 M x N x K,利用现有高度优化的线性代数实现 (cuBLAS, CUTLASS) 获取 tiling size、swizzling pattern、对应 duration 等配置。

2

通信带宽曲线采样: 对给定 GPU 执行通信原语,带宽随数据大小连续变化。用多个密集点采样带宽曲线。

3

SM 资源竞争分析: 确定通信 kernel 占用的 SM 数量,从 GEMM 可用的 SM 中扣除。

在线阶段: 预测搜索算法

基于离线阶段获取的 GEMM 配置和带宽曲线,在线阶段使用延迟预测器替代在线 profiling 来搜索最优 wave group 划分。

对于每个候选划分 G = (G1, G2, ..., GP):

  • • 对每个组 G_i,通过插值带宽曲线获取通信延迟 t_m = interp_latency(bdw_curve, data_size)
  • • 计算该组的计算延迟 t_p = gemm_config.duration / T * |G_i|
  • • 累积延迟: t_acc_m = max(t_acc_p, t_acc_m) + t_m, t_acc_p += t_p
  • • 最后加上最后一组的通信延迟
  • • 选择总累积延迟最小的划分作为最优解
// Algorithm 1: Grouping Tuning Algorithm (简化伪代码)
// 离线阶段
gemm_config = get_config(M, N, K, gpu) // 获取 GEMM 配置
T = gemm_config.tile_num / (gpu.sm_num - comm_op.sm_num) // wave 数量
bdw_curve = sample_bandwidth(comm_op, gpu) // 采样带宽曲线
// 在线阶段: 预测搜索
candidates = get_candidates(T) // 生成候选划分
t_min = +inf
for G in candidates:
t_acc_p = 0, t_acc_m = 0 // 累积计算/通信延迟
for i, G_i in enumerate(G):
data_size = get_data_size(G_{i-1})
t_m = interp_latency(bdw_curve, data_size) // 预测通信延迟
t_p = gemm_config.duration / T * |G_i| // 计算延迟
t_acc_m = max(t_acc_p, t_acc_m) + t_m
t_acc_p = t_acc_p + t_p
# 加上最后一组通信延迟
t_acc_m += interp_latency(bdw_curve, get_data_size(G_{-1}))
if t_acc_m < t_min:
t_min = t_acc_m, G_optimal = G
return G_optimal

五、CUTLASS 扩展实现

扩展主循环 (Main Loop)

FlashOverlap 扩展 CUTLASS 矩阵例程,保留核心计算循环,在其中嵌入信号和重排逻辑。在每个 wave 的 tile 完成时,原子性地更新计数表。当计数表指示某个 group 完成时,触发通信前的重排操作。

Epilogue 集成

GEMM 执行包含主循环 (main loop) 和 epilogue。主循环执行核心乘加操作,占 GEMM 大部分时间。Epilogue 指矩阵乘法后执行的 element-wise 操作 (如 ReLU, SiLU, bias addition)。这些操作通常与前面的矩阵乘法融合到单个 GPU kernel 中,消除冗余内存访问和 kernel 启动开销。FlashOverlap 将通信后重排与 element-wise 操作融合,避免额外的 kernel 启动。

专用监控内核

通过原子计数器 (atomic counters) 跟踪进度。独立执行流分别管理计算和传输工作流。信号机制使用 CUDA 事件或共享内存原子变量实现低开销通知。

实验评估结果

1.65x
最高加速比 (通过重叠实现)
4%
基线划分 (最细粒度) 最优比例
17.34%
基线划分平均性能退化

典型 GEMM 性能对比

GEMM performance comparison across different overlap methods

性能对比图 — FlashOverlap 通过最优 wave group 划分,在多种 GEMM 尺寸下均优于基线方法,平均加速比显著

Wave Pattern 验证

在 RTX 4090 上对 GEMM (M=2048, N=K=8192) 的 profiling 显示,tile 完成时间可明确分为 4 个 distinct waves,与 tile 数量 (512) 除以 SM 数量 (128) 的结果一致。完成顺序与内存地址 (tile index) 不对齐,证实了 block swizzling 的影响。

调优必要性验证

在 4 张 RTX 4090 上对 50+ 种 GEMM 尺寸使用 AllReduce 的测试中,只有 4% 的情况最细粒度划分 (每个 group 1 个 wave) 是最优的。使用基线划分平均导致 17.34% 性能退化,证明 wave group 调优的必要性。

预测搜索精度

预测搜索方法通过离线采样的带宽曲线和 GEMM 配置,在线预测各候选划分的延迟,避免了昂贵的在线 profiling (1 分钟+ vs 毫秒级预测),同时保持高精度找到最优或接近最优的划分。

测试硬件与模型

实验覆盖多种硬件和模型场景:

  • NVIDIA A800: Llama3-70B 推理 (TP=8), Mixtral-8x7B 训练 (EP=8), Step-Video-T2V 推理 (TP=4), Llama2-7B 训练 (TP=4, PP=2)
  • NVIDIA RTX 4090: 50+ GEMM 尺寸, AllReduce, 4 GPU 测试
  • • 通信原语覆盖: AllReduce, ReduceScatter, All-to-All

核心洞察

1. 信号机制解耦计算与通信

FlashOverlap 的核心创新是使用信号 (signal) 而非分解或融合来链式连接计算-通信数据依赖。GEMM 保持单个 kernel 执行,仅在部分数据完成时发送轻量信号触发通信,实现真正的无干扰计算。

2. Wave 模式是天然的重叠单元

GEMM 执行中 tile 完成时间的 wave 模式 (与 SM 数量对应的批次) 提供了天然的重叠粒度。使用 wave 而非单个 tile 作为信号单元,在几乎不损失重叠机会的情况下提升带宽利用率。

3. Wave Group 调优至关重要

最细粒度划分仅在 4% 情况下最优,平均退化 17.34%。通信带宽利用率和 API 调用开销的权衡使得最优划分随 GEMM 尺寸、通信原语和硬件变化。预测搜索方法使实时调优成为可能。

4. 重排实现通信不可知

预通信重排创建连续地址使直接调用 NCCL API 成为可能,无需为每种通信原语自定义实现。通信后重排恢复原始数据顺序。这对 AllReduce、ReduceScatter 和 All-to-All 均有适配方案。