接口用例:model Training or Inference Prefilling

README 文档的 ### Example use in model training or inference prefilling 部分,定义了多个函数,这些函数主要用于模型训练或推理预填充阶段,涉及通信缓冲区管理、数据分发和合并等操作。

1. get_buffer

其核心作用是管理并初始化一个全局的通信缓冲区(Buffer,确保缓冲区大小满足分布式通信需求(如 MoE 中的专家路由),并在必要时重新分配缓冲区。以下是详细解释:

  • 作用:该函数用于获取或初始化一个 Buffer 实例。首先,根据分发和合并操作的配置,计算 NVLink 和 RDMA 所需的缓冲区大小。然后,检查全局变量 _buffer 是否存在,或者其配置是否满足当前需求。如果不满足,则重新创建一个新的 Buffer 实例。最后返回这个 Buffer 实例。

1. 函数定义与参数

 
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
 
  • 功能:获取或创建一个 deep_ep.Buffer 实例,用于分布式环境下的专家并行(MoE)通信。
  • 参数
  • grouptorch.distributed.ProcessGroup 对象,定义参与通信的进程组(如节点内所有进程)。
  • hidden_bytes:单个 token 的隐藏层数据大小(字节),用于计算缓冲区所需的内存空间。

2. 全局缓冲区变量

 
global _buffer
 
  • _buffer 是全局变量,存储唯一的 Buffer 实例,避免重复创建缓冲区(优化性能)。

3. 计算缓冲区大小需求

 
# 获取调度(dispatch)和聚合(combine)操作的配置
 
num_nvl_bytes, num_rdma_bytes = 0, 0
 
for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
 
# 计算 NVL 缓冲区大小(节点内通信,如 GPU 间共享内存)
 
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
 
# 计算 RDMA 缓冲区大小(远程直接内存访问,用于跨节点通信)
 
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
 
  • 配置获取
  • Buffer.get_dispatch_config(group.size()):获取“专家调度”(将 token 分发到对应专家)的配置。
  • Buffer.get_combine_config(group.size()):获取“结果聚合”(将专家输出聚合回原进程)的配置。

(注:配置中包含通信优化参数,如分块大小,可通过测试自动调优,见代码注释)。

  • 缓冲区大小计算
  • config.get_nvl_buffer_size_hint(…):根据 hidden_bytes 和进程数(group.size()),估算节点内通信(NVL)所需的最小缓冲区大小(字节)。
  • config.get_rdma_buffer_size_hint(…):类似地,估算跨节点通信(RDMA)所需的最小缓冲区大小。
  • max(…):确保缓冲区大小满足调度和聚合两种操作的最大需求。

4. 缓冲区初始化或复用

 
# 若缓冲区未创建、进程组不匹配或大小不足,则重新分配
 
if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
 
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
 
return _buffer
 
  • 条件判断:仅在以下情况重新创建 Buffer
  1. _buffer 未初始化(首次调用)。
  2. 进程组 group 变更(如通信范围改变)。
  3. 当前缓冲区的 NVL 或 RDMA 空间不足(num_nvl_bytes/num_rdma_bytes 大于现有大小)。
  • 创建缓冲区

Buffer(group, num_nvl_bytes, num_rdma_bytes) 初始化一个新缓冲区,分配 num_nvl_bytes(NVL 内存)和 num_rdma_bytes(RDMA 内存),并绑定到 group 进程组。

  • 返回值:返回全局 _buffer 实例,确保整个程序中使用同一个缓冲区进行通信。

总结

get_buffer缓冲区的“管家”函数,通过以下逻辑确保高效通信:

  1. 按需分配:根据通信操作(调度/聚合)和数据大小(hidden_bytes)动态计算所需内存。
  2. 复用优化:避免重复创建缓冲区,仅在必要时(如大小不足、进程组变更)重新分配。
  3. 适配分布式:通过 group 和配置参数,确保缓冲区兼容当前分布式环境(进程数、通信模式)。

2. get_hidden_bytes

其作用是计算单个 token 的隐藏层数据所占用的字节数,用于后续通信缓冲区大小的估算(如 get_buffer 函数中计算 num_nvl_bytesnum_rdma_bytes)。

  • 作用:计算输入张量 x 的隐藏字节数。如果 x 是元组,则取元组的第一个元素;否则直接使用 x。然后计算该张量第二维的大小,并乘以元素的字节大小(至少为 2),得到隐藏字节数。

1. 函数定义与参数

 
def get_hidden_bytes(x: torch.Tensor) -> int:
 
  • 功能:计算单个 token 的隐藏层数据大小(字节)。
  • 参数x 是输入的隐藏层张量(或包含隐藏层数据的元组,如 FP8 格式可能包含数据和缩放因子)。

2. 处理输入数据格式

 
t = x[0] if isinstance(x, tuple) else x
 
  • 逻辑
  • x 是元组(如 FP8 格式的隐藏层数据,可能包含 (data_tensor, scale_tensor)),则取元组的第一个元素 x[0](即实际存储隐藏层数据的张量)。
  • x 是普通张量(如 BF16/FP32 格式),则直接使用 x
  • 目的:统一数据格式,确保后续计算基于隐藏层数据张量本身,而非附加信息(如缩放因子)。

3. 计算字节数

 
return t.size(1) * max(t.element_size(), 2)
 
  • 分步解析
  1. t.size(1):获取隐藏层的维度大小(即单个 token 的特征数)。
  • 假设 t 的形状为 (num_tokens, hidden_dim)(批量 token 的隐藏层数据),则 t.size(1) 对应 hidden_dim(如 7168,常见于大语言模型)。
  1. t.element_size():获取张量元素的字节数(如 BF16 为 2 字节,FP32 为 4 字节)。
  2. max(t.element_size(), 2):确保每个元素至少按 2 字节计算。
  • 原因:即使元素类型小于 2 字节(如理论上的 FP4),通信缓冲区仍需按最小 2 字节对齐(硬件或驱动要求),避免内存访问错误。
  1. 乘积结果hidden_dim * 每个元素字节数,即单个 token 的隐藏层数据总字节数。

示例与场景

  • 场景 1:BF16 张量(x 是普通张量)
  • t.element_size() = 2(BF16 占 2 字节),max(2, 2) = 2
  • t.size(1) = 7168(隐藏维度),则 hidden_bytes = 7168 * 2 = 14336 字节/ token。
  • 场景 2:FP8 元组(x(data_tensor, scale_tensor)
  • t = x[0](取数据张量),假设 t.element_size() = 1(FP8 占 1 字节),max(1, 2) = 2(按 2 字节对齐)。
  • t.size(1) = 7168,则 hidden_bytes = 7168 * 2 = 14336 字节/ token。

总结

get_hidden_bytes 是通信缓冲区大小计算的基础工具函数,通过:

  1. 统一处理不同数据格式(张量/元组),提取核心数据张量;
  2. 结合隐藏维度和元素字节数(并强制最小 2 字节对齐),计算单个 token 的隐藏层数据字节数。

其结果直接用于 get_buffer 函数,确保分配的缓冲区足以容纳分布式通信中的数据传输需求。

3. dispatch_forward

其核心作用是在 MoE(Mixture of Experts)模型中执行“专家调度”逻辑,即将输入 token 根据 topk_idx(选中的专家索引)分发到对应进程/专家,并返回调度后的结果。以下是详细解析:

  • 作用:执行混合专家(MoE)分发操作。首先调用 _buffer.get_dispatch_layout 计算分发布局,包括每个进程的令牌数、每个 RDMA 进程的令牌数等。然后调用 _buffer.dispatch 方法进行实际的分发操作,返回接收的数据、接收的 Top-K 索引和权重、每个专家接收的令牌数列表、句柄和事件对象。

1. 函数定义与核心功能

 
def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
 
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
 
num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
 
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
 
  • 功能:实现 MoE 前向传播中的“token 到专家的路由”,将输入数据 x 根据 topk_idx 分发到对应专家所在的进程,并返回专家处理前的中间数据。
  • 应用场景:大语言模型训练或推理的预填充阶段(prefilling),负责专家并行中的跨进程通信调度。

2. 参数解析

| 参数名 | 类型 | 说明 |

|----------------------|---------------------------------------|----------------------------------------------------------------------|

| x | 张量或张量元组 | 输入的隐藏层数据(如 BF16 张量或 FP8 格式的 (data, scale) 元组)。|

| topk_idx | torch.Tensor | 形状 (num_tokens, num_topk),每个 token 选中的 num_topk 个专家索引。|

| topk_weights | torch.Tensor | 形状 (num_tokens, num_topk),选中专家的权重(用于后续加权求和)。|

| num_experts | int | 专家总数(全局)。|

| previous_event | Optional[EventOverlap] | 可选的 CUDA 事件依赖,用于通信 - 计算重叠(如等待前一个 kernel 完成后再调度)。|

3. 关键逻辑步骤

步骤 1:声明全局通信缓冲区

 
global _buffer
 
  • _buffer 是全局 deep_ep.Buffer 实例,负责管理分布式通信所需的内存和通信资源(在 get_buffer 中初始化)。

步骤 2:计算调度布局(Layout)

 
# 计算 token 分发的元数据(如每个进程/专家的 token 数量、是否属于当前进程等)
 
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
 
_buffer.get_dispatch_layout(topk_idx, num_experts,
 
previous_event=previous_event, async_finish=True,
 
allocate_on_comm_stream=previous_event is not None)
 
  • 功能:通过 _buffer.get_dispatch_layout 预计算 token 分发的关键元数据,为实际通信做准备。
  • 返回值解析
  • num_tokens_per_rank:每个进程(rank)需要处理的 token 数量。
  • num_tokens_per_rdma_rank:跨节点通信(RDMA)中每个进程的 token 数量。
  • num_tokens_per_expert:每个专家需要处理的 token 数量。
  • is_token_in_rank:布尔张量,标记每个 token 是否属于当前进程。
  • previous_event:更新后的 CUDA 事件(若传入 previous_event,则等待其完成后再继续)。

步骤 3:执行实际调度(Dispatch)

 
# 执行 MoE 调度,将 token 分发到对应专家/进程
 
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
 
_buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
 
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
 
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
 
previous_event=previous_event, async_finish=True,
 
allocate_on_comm_stream=True)
 
  • 核心操作:调用 _buffer.dispatch 完成实际的跨进程数据传输,将输入 xtopk_idxtopk_weights 分发到对应专家所在的进程。
  • 关键参数
  • async_finish=True:启用异步调度,通信操作不阻塞 CPU(优化吞吐量)。
  • allocate_on_comm_stream=True:在通信流(而非计算流)上分配内存,避免占用计算资源。
  • 布局元数据(num_tokens_per_rank 等):指导如何高效分发数据。

步骤 4:返回调度结果

 
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
 
  • 返回值解析
  • recv_x:当前进程接收到的、需要由本地专家处理的 token 数据(格式与输入 x 一致,如张量或 FP8 元组)。
  • recv_topk_idx/recv_topk_weights:分发后当前进程本地专家对应的专家索引和权重。
  • num_recv_tokens_per_expert_list:列表,每个元素为当前进程中每个本地专家接收到的 token 数量。
  • handle:调度句柄(元组),包含后续聚合(combine)所需的元数据(如 token 前缀矩阵)。
  • eventEventOverlap 对象,用于跟踪调度操作的 CUDA 事件(便于后续同步或重叠)。

4. 关键特性与注意事项

  • 通信 - 计算重叠:通过 previous_event 可指定依赖的 CUDA 事件(如等待前一个计算 kernel 完成),实现通信与计算的并行。
  • 异步执行async_finish=True 使调度操作异步执行,CPU 无需等待 GPU 通信完成即可继续后续逻辑(需通过 event 显式同步)。
  • CUDA 图兼容性:注释提到“CPU 会等待 GPU 信号”,因此默认不兼容 CUDA 图(除非指定 num_worst_tokens,但仅限节点内场景)。

总结

dispatch_forward 是 MoE 前向传播的核心函数,通过以下流程实现 token 到专家的高效分发:

  1. 布局计算:预计算 token 分布 metadata(进程/专家的 token 数量)。
  2. 实际调度:基于 metadata 调用缓冲区通信接口,跨进程传输数据。
  3. 返回结果:提供本地专家所需的输入数据、元数据及事件句柄,为后续专家计算和结果聚合(combine_forward)做准备。

其设计优化了分布式通信效率,支持异步/重叠操作,是 DeepEP 库实现高效专家并行的关键组件。

4. dispatch_backward

其核心作用是实现 MoE(Mixture of Experts)模型中“调度前向”(dispatch_forward)的反向传播逻辑,通过聚合各专家返回的梯度,生成原始输入的梯度。以下是详细解析:

  • 作用:执行 MoE 分发操作的反向过程,实际上是一个合并操作。调用 _buffer.combine 方法,将梯度数据和 Top-K 权重梯度进行合并,返回合并后的梯度数据、合并后的 Top-K 权重梯度和事件对象。

1. 函数定义与核心功能

 
def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
 
Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
 
  • 功能:在 MoE 反向传播中,将专家处理后的梯度(grad_recv_xgrad_recv_topk_weights)聚合回原进程,生成与 dispatch_forward 输入对应的梯度(如原始 xtopk_weights 的梯度)。
  • 本质:前向传播的“调度”(dispatch_forward)将 token 分发到专家,反向传播则需将专家的梯度“聚合”回原进程,因此调度的反向过程本质是聚合操作(见代码注释:“The backward process of MoE dispatch is actually a combine”)。

2. 参数解析

| 参数名 | 类型 | 说明 |

|-------------------------|-----------------------|----------------------------------------------------------------------|

| grad_recv_x | torch.Tensor | 专家处理后返回的隐藏层梯度(形状与 dispatch_forward 返回的 recv_x 一致)。|

| grad_recv_topk_weights| torch.Tensor | 专家处理后返回的 topk_weights 梯度(形状与 dispatch_forward 返回的 recv_topk_weights 一致)。|

| handle | Tuple | 调度句柄,由 dispatch_forward 返回,包含聚合所需的元数据(如 token 路由信息、进程前缀矩阵等)。|

3. 关键逻辑步骤

步骤 1:声明全局通信缓冲区

 
global _buffer
 
  • _buffer 是全局 deep_ep.Buffer 实例,负责管理分布式通信资源(与 dispatch_forward 共用,避免重复初始化)。

步骤 2:调用 combine 聚合梯度

 
combined_grad_x, combined_grad_recv_topk_weights, event = \
 
_buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)
 
  • 核心操作:通过 _buffer.combine 聚合梯度,具体逻辑为:
  1. 输入梯度grad_recv_x(隐藏层梯度)和 grad_recv_topk_weights(权重梯度)是专家处理后返回的局部梯度。
  2. 句柄依赖handle 提供聚合所需的路由元数据(如哪些 token 来自哪个进程),确保梯度按前向调度的逆过程聚合。
  3. 异步执行async_finish=True 使聚合操作异步执行(CPU 无需等待 GPU 通信完成,通过 event 跟踪状态)。

步骤 3:返回聚合结果

 
return combined_grad_x, combined_grad_recv_topk_weights, event
 
  • 返回值解析
  • combined_grad_x:聚合后的隐藏层梯度,对应 dispatch_forward 输入 x 的梯度。
  • combined_grad_recv_topk_weights:聚合后的权重梯度,对应 dispatch_forward 输入 topk_weights 的梯度。
  • eventEventOverlap 对象,用于跟踪聚合操作的 CUDA 事件(需显式同步时使用,如 event.current_stream_wait())。

4. 核心设计思想:前向 - 反向操作对称

MoE 中“调度”与“聚合”是对称操作

  • 前向dispatch_forward 将 token 从“原进程”分发到“专家进程”(多对多通信)。
  • 反向dispatch_backward 通过 combine 将梯度从“专家进程”聚合回“原进程”(多对多通信的逆过程)。

这种设计避免了重复开发反向传播逻辑,直接复用 combine 接口实现梯度聚合,简化了代码并确保通信逻辑一致性。

总结

dispatch_backward 是 MoE 反向传播的关键函数,通过以下方式实现梯度聚合:

  1. 利用“调度前向的反向是聚合”的对称特性,复用 _buffer.combine 接口。
  2. 基于前向调度生成的 handle 元数据,确保梯度按正确路由聚合。
  3. 支持异步执行(async_finish=True),提升反向传播效率。

最终返回与 dispatch_forward 输入对应的梯度(combined_grad_xcombined_grad_recv_topk_weights),完成 MoE 反向传播的梯度链。

5. combine_forward

其核心作用是在 MoE(Mixture of Experts)模型前向传播中执行“结果聚合”逻辑,即将专家处理后的 token 结果(如隐藏层输出)聚合回原进程,生成与输入 dispatch_forward 对应的完整输出。以下是详细解析:

  • 作用:执行 MoE 合并操作。调用 _buffer.combine 方法,将输入数据进行合并,返回合并后的数据和事件对象。

1. 函数定义与核心功能

 
def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
 
Tuple[torch.Tensor, EventOverlap]:
 
  • 功能:实现 MoE 前向传播中的“专家结果聚合”,将各专家处理后的 token 数据(x)根据路由信息(handle)聚合回原进程,生成与原始输入 token 数量匹配的输出张量。
  • 应用场景:紧跟在 dispatch_forward 和专家计算之后,是 MoE 前向传播的收尾步骤(例如:将 8 个专家的输出聚合为完整的 batch 结果)。

2. 参数解析

| 参数名 | 类型 | 说明 |

|----------------------|---------------------------------------|----------------------------------------------------------------------|

| x | torch.Tensor | 专家处理后的隐藏层输出(形状与 dispatch_forward 返回的 recv_x 一致,即当前进程本地专家处理的 token 结果)。|

| handle | Tuple | 聚合句柄,由 dispatch_forward 返回,包含 token 路由元数据(如原进程索引、token 在原进程中的位置等),是聚合的“地图”。|

| previous_event | Optional[EventOverlap] | 可选的 CUDA 事件依赖,用于通信 - 计算重叠(如等待专家计算 kernel 完成后再启动聚合通信)。|

3. 关键逻辑步骤

步骤 1:声明全局通信缓冲区

 
global _buffer
 
  • _buffer 是全局 deep_ep.Buffer 实例,提供分布式通信能力(与 dispatch_forward 共用,确保通信资源统一管理)。

步骤 2:调用 combine 执行聚合

 
# 执行 MoE 聚合,将专家输出合并回原进程
 
combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
 
allocate_on_comm_stream=previous_event is not None)
 
  • 核心操作:通过 _buffer.combine 完成跨进程聚合,具体逻辑为:
  1. 输入数据x 是当前进程本地专家处理后的 token 结果(如形状 (local_tokens, hidden_dim))。
  2. 路由依赖handle 提供聚合所需的元数据(如每个 token 来自哪个原进程、在原进程中的位置索引),确保结果按 dispatch_forward 的逆过程准确聚合。
  3. 异步与性能优化
  • async_finish=True:聚合操作异步执行,CPU 无需等待 GPU 通信完成即可继续后续逻辑(通过 event 跟踪状态)。
  • allocate_on_comm_stream=previous_event is not None:若存在 previous_event(通信 - 计算重叠场景),则在通信流(而非计算流)上分配内存,避免占用计算资源。

步骤 3:返回聚合结果

 
return combined_x, event
 
  • 返回值解析
  • combined_x:聚合后的完整输出张量,形状与 dispatch_forward 的输入 x 一致(如 (num_tokens, hidden_dim)),每个 token 对应其选中专家的加权结果。
  • eventEventOverlap 对象,用于跟踪聚合操作的 CUDA 事件(需显式同步时调用 event.current_stream_wait(),确保 GPU 完成聚合后再使用 combined_x)。

4. 与 dispatch_forward 的关系

MoE 前向传播是“调度 - 计算 - 聚合”的完整流程:

  • dispatch_forward:将输入 token 按 topk_idx 分发到对应专家(“分”)。
  • 专家计算:各专家在本地进程中处理分配的 token(“算”)。
  • combine_forward:将专家处理结果按原路由聚合回原进程(“合”)。

二者通过 handle 关联(dispatch_forward 返回 handlecombine_forward 接收 handle),确保 token 分发与聚合的路由一致性。

总结

combine_forward 是 MoE 前向传播的“最后一公里”,通过以下方式实现专家结果聚合:

  1. 基于 dispatch_forward 生成的 handle 路由元数据,确保聚合的准确性。
  2. 调用 _buffer.combine 完成跨进程通信,支持异步执行和通信 - 计算重叠。
  3. 返回聚合后的完整输出 combined_x 和跟踪事件 event,为后续模型层提供输入。

该函数是专家并行(MoE)中“数据分发 - 处理 - 聚合”闭环的核心组件,直接影响模型输出的正确性和训练/推理效率。

6. combine_backward

其核心作用是实现 MoE(Mixture of Experts)模型中“聚合前向”(combine_forward)的反向传播逻辑,通过将聚合结果的梯度(grad_combined_x)分发回原专家进程,生成与 combine_forward 输入对应的梯度。以下是详细解析:

  • 作用:执行 MoE 合并操作的反向过程,实际上是一个分发操作。调用 _buffer.dispatch 方法,将合并后的梯度数据进行分发,返回分发后的梯度数据和事件对象。

1. 函数定义与核心功能

 
def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
 
handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
 
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
 
  • 功能:在 MoE 反向传播中,将聚合后的梯度(grad_combined_x)分发回原专家进程,生成与 combine_forward 输入 x 对应的梯度(即专家处理前的输入梯度)。
  • 本质:前向传播的“聚合”(combine_forward)将专家结果聚合到原进程,反向传播则需将聚合梯度“分发”回专家进程,因此聚合的反向过程本质是调度操作(见代码注释:“The backward process of MoE combine is actually a dispatch”)。

2. 参数解析

| 参数名 | 类型 | 说明 |

|----------------------|---------------------------------------|----------------------------------------------------------------------|

| grad_combined_x | 张量或张量元组 | 聚合结果 combined_x 的梯度(形状与 combine_forward 返回的 combined_x 一致,如 (num_tokens, hidden_dim))。|

| handle | Tuple | 聚合句柄,由 combine_forward 关联的 dispatch_forward 返回,包含梯度分发所需的元数据(如专家进程索引、token 路由信息等)。|

| previous_event | Optional[EventOverlap] | 可选的 CUDA 事件依赖,用于通信 - 计算重叠(如等待前一个梯度计算 kernel 完成后再启动分发通信)。|

3. 关键逻辑步骤

步骤 1:声明全局通信缓冲区

 
global _buffer
 
  • _buffer 是全局 deep_ep.Buffer 实例,提供分布式通信能力(与 combine_forwarddispatch_forward 共用,确保通信资源统一管理)。

步骤 2:调用 dispatch 分发梯度

 
# 聚合的反向过程本质是调度:将梯度分发回原专家进程
 
grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
 
previous_event=previous_event,
 
allocate_on_comm_stream=previous_event is not None)
 
  • 核心操作:通过 _buffer.dispatch 分发梯度,具体逻辑为:
  1. 输入梯度grad_combined_x 是聚合结果的梯度(如模型最终输出对 combine_forward 输出 combined_x 的梯度)。
  2. 句柄依赖handle 提供梯度分发所需的路由元数据(如哪些专家参与了聚合、每个专家对应的 token 范围),确保梯度按 combine_forward 的逆过程准确分发到原专家进程。
  3. 性能优化
  • async_finish=True:梯度分发异步执行,CPU 无需等待 GPU 通信完成即可继续后续逻辑(通过 event 跟踪状态)。
  • allocate_on_comm_stream=previous_event is not None:若存在 previous_event(通信 - 计算重叠场景),则在通信流(而非计算流)上分配内存,避免占用计算资源。

步骤 3:返回梯度与事件

 
return grad_x, event
 
  • 返回值解析
  • grad_x:分发后的梯度,对应 combine_forward 输入 x 的梯度(即专家处理前的输入梯度,需传递给专家的反向传播逻辑)。
  • eventEventOverlap 对象,用于跟踪梯度分发操作的 CUDA 事件(需显式同步时调用 event.current_stream_wait(),确保 GPU 完成分发后再使用 grad_x)。

4. 与 MoE 整体流程的关系

MoE 前向 - 反向传播是“调度 - 聚合 - 聚合反向 - 调度反向”的闭环:

  • 前向dispatch_forward(分发 token 到专家)→ 专家计算 → combine_forward(聚合专家结果)。
  • 反向combine_backward(本函数,分发聚合梯度到专家)→ 专家反向计算 → dispatch_backward(聚合专家梯度到原进程)。

combine_backward 是连接聚合结果梯度与专家梯度的关键环节,确保梯度能从最终输出反向传播到各专家,实现 MoE 模型的端到端训练。

总结

combine_backward 函数通过以下方式实现 MoE 聚合操作的反向传播:

  1. 利用对称特性:基于“聚合的反向是调度”的逻辑,复用 _buffer.dispatch 接口分发梯度。
  2. 依赖路由元数据:通过 handle 确保梯度按前向聚合的逆路径准确分发到原专家进程。
  3. 异步与性能优化:支持异步执行(async_finish=True)和通信流内存分配,提升反向传播效率。

最终返回专家输入对应的梯度 grad_x 和跟踪事件 event,为专家层的反向传播提供输入。


接口用例:inference Decoding TODO