接口用例:model Training or Inference Prefilling
在 README 文档的 ### Example use in model training or inference prefilling 部分,定义了多个函数,这些函数主要用于模型训练或推理预填充阶段,涉及通信缓冲区管理、数据分发和合并等操作。

1. get_buffer
其核心作用是管理并初始化一个全局的通信缓冲区(Buffer),确保缓冲区大小满足分布式通信需求(如 MoE 中的专家路由),并在必要时重新分配缓冲区。以下是详细解释:
- 作用:该函数用于获取或初始化一个
Buffer实例。首先,根据分发和合并操作的配置,计算 NVLink 和 RDMA 所需的缓冲区大小。然后,检查全局变量_buffer是否存在,或者其配置是否满足当前需求。如果不满足,则重新创建一个新的Buffer实例。最后返回这个Buffer实例。
1. 函数定义与参数
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
- 功能:获取或创建一个
deep_ep.Buffer实例,用于分布式环境下的专家并行(MoE)通信。 - 参数:
group:torch.distributed.ProcessGroup对象,定义参与通信的进程组(如节点内所有进程)。hidden_bytes:单个 token 的隐藏层数据大小(字节),用于计算缓冲区所需的内存空间。
2. 全局缓冲区变量
global _buffer
_buffer是全局变量,存储唯一的Buffer实例,避免重复创建缓冲区(优化性能)。
3. 计算缓冲区大小需求
# 获取调度(dispatch)和聚合(combine)操作的配置
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
# 计算 NVL 缓冲区大小(节点内通信,如 GPU 间共享内存)
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
# 计算 RDMA 缓冲区大小(远程直接内存访问,用于跨节点通信)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
- 配置获取:
Buffer.get_dispatch_config(group.size()):获取“专家调度”(将 token 分发到对应专家)的配置。Buffer.get_combine_config(group.size()):获取“结果聚合”(将专家输出聚合回原进程)的配置。
(注:配置中包含通信优化参数,如分块大小,可通过测试自动调优,见代码注释)。
- 缓冲区大小计算:
config.get_nvl_buffer_size_hint(…):根据hidden_bytes和进程数(group.size()),估算节点内通信(NVL)所需的最小缓冲区大小(字节)。config.get_rdma_buffer_size_hint(…):类似地,估算跨节点通信(RDMA)所需的最小缓冲区大小。max(…):确保缓冲区大小满足调度和聚合两种操作的最大需求。
4. 缓冲区初始化或复用
# 若缓冲区未创建、进程组不匹配或大小不足,则重新分配
if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
_buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
return _buffer
- 条件判断:仅在以下情况重新创建
Buffer:
_buffer未初始化(首次调用)。- 进程组
group变更(如通信范围改变)。 - 当前缓冲区的 NVL 或 RDMA 空间不足(
num_nvl_bytes/num_rdma_bytes大于现有大小)。
- 创建缓冲区:
Buffer(group, num_nvl_bytes, num_rdma_bytes) 初始化一个新缓冲区,分配 num_nvl_bytes(NVL 内存)和 num_rdma_bytes(RDMA 内存),并绑定到 group 进程组。
- 返回值:返回全局
_buffer实例,确保整个程序中使用同一个缓冲区进行通信。
总结
get_buffer 是缓冲区的“管家”函数,通过以下逻辑确保高效通信:
- 按需分配:根据通信操作(调度/聚合)和数据大小(
hidden_bytes)动态计算所需内存。 - 复用优化:避免重复创建缓冲区,仅在必要时(如大小不足、进程组变更)重新分配。
- 适配分布式:通过
group和配置参数,确保缓冲区兼容当前分布式环境(进程数、通信模式)。
2. get_hidden_bytes
其作用是计算单个 token 的隐藏层数据所占用的字节数,用于后续通信缓冲区大小的估算(如 get_buffer 函数中计算 num_nvl_bytes 和 num_rdma_bytes)。
- 作用:计算输入张量
x的隐藏字节数。如果x是元组,则取元组的第一个元素;否则直接使用x。然后计算该张量第二维的大小,并乘以元素的字节大小(至少为 2),得到隐藏字节数。
1. 函数定义与参数
def get_hidden_bytes(x: torch.Tensor) -> int:
- 功能:计算单个 token 的隐藏层数据大小(字节)。
- 参数:
x是输入的隐藏层张量(或包含隐藏层数据的元组,如 FP8 格式可能包含数据和缩放因子)。
2. 处理输入数据格式
t = x[0] if isinstance(x, tuple) else x
- 逻辑:
- 若
x是元组(如 FP8 格式的隐藏层数据,可能包含(data_tensor, scale_tensor)),则取元组的第一个元素x[0](即实际存储隐藏层数据的张量)。 - 若
x是普通张量(如 BF16/FP32 格式),则直接使用x。 - 目的:统一数据格式,确保后续计算基于隐藏层数据张量本身,而非附加信息(如缩放因子)。
3. 计算字节数
return t.size(1) * max(t.element_size(), 2)
- 分步解析:
t.size(1):获取隐藏层的维度大小(即单个 token 的特征数)。
- 假设
t的形状为(num_tokens, hidden_dim)(批量 token 的隐藏层数据),则t.size(1)对应hidden_dim(如 7168,常见于大语言模型)。
t.element_size():获取张量元素的字节数(如 BF16 为 2 字节,FP32 为 4 字节)。max(t.element_size(), 2):确保每个元素至少按 2 字节计算。
- 原因:即使元素类型小于 2 字节(如理论上的 FP4),通信缓冲区仍需按最小 2 字节对齐(硬件或驱动要求),避免内存访问错误。
- 乘积结果:
hidden_dim * 每个元素字节数,即单个 token 的隐藏层数据总字节数。
示例与场景
- 场景 1:BF16 张量(
x是普通张量) t.element_size() = 2(BF16 占 2 字节),max(2, 2) = 2。- 若
t.size(1) = 7168(隐藏维度),则hidden_bytes = 7168 * 2 = 14336字节/ token。 - 场景 2:FP8 元组(
x是(data_tensor, scale_tensor)) t = x[0](取数据张量),假设t.element_size() = 1(FP8 占 1 字节),max(1, 2) = 2(按 2 字节对齐)。- 若
t.size(1) = 7168,则hidden_bytes = 7168 * 2 = 14336字节/ token。
总结
get_hidden_bytes 是通信缓冲区大小计算的基础工具函数,通过:
- 统一处理不同数据格式(张量/元组),提取核心数据张量;
- 结合隐藏维度和元素字节数(并强制最小 2 字节对齐),计算单个 token 的隐藏层数据字节数。
其结果直接用于 get_buffer 函数,确保分配的缓冲区足以容纳分布式通信中的数据传输需求。
3. dispatch_forward
其核心作用是在 MoE(Mixture of Experts)模型中执行“专家调度”逻辑,即将输入 token 根据 topk_idx(选中的专家索引)分发到对应进程/专家,并返回调度后的结果。以下是详细解析:
- 作用:执行混合专家(MoE)分发操作。首先调用
_buffer.get_dispatch_layout计算分发布局,包括每个进程的令牌数、每个 RDMA 进程的令牌数等。然后调用_buffer.dispatch方法进行实际的分发操作,返回接收的数据、接收的 Top-K 索引和权重、每个专家接收的令牌数列表、句柄和事件对象。
1. 函数定义与核心功能
def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
- 功能:实现 MoE 前向传播中的“token 到专家的路由”,将输入数据
x根据topk_idx分发到对应专家所在的进程,并返回专家处理前的中间数据。 - 应用场景:大语言模型训练或推理的预填充阶段(prefilling),负责专家并行中的跨进程通信调度。
2. 参数解析
| 参数名 | 类型 | 说明 |
|----------------------|---------------------------------------|----------------------------------------------------------------------|
| x | 张量或张量元组 | 输入的隐藏层数据(如 BF16 张量或 FP8 格式的 (data, scale) 元组)。|
| topk_idx | torch.Tensor | 形状 (num_tokens, num_topk),每个 token 选中的 num_topk 个专家索引。|
| topk_weights | torch.Tensor | 形状 (num_tokens, num_topk),选中专家的权重(用于后续加权求和)。|
| num_experts | int | 专家总数(全局)。|
| previous_event | Optional[EventOverlap] | 可选的 CUDA 事件依赖,用于通信 - 计算重叠(如等待前一个 kernel 完成后再调度)。|
3. 关键逻辑步骤
步骤 1:声明全局通信缓冲区
global _buffer
_buffer是全局deep_ep.Buffer实例,负责管理分布式通信所需的内存和通信资源(在get_buffer中初始化)。
步骤 2:计算调度布局(Layout)
# 计算 token 分发的元数据(如每个进程/专家的 token 数量、是否属于当前进程等)
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
_buffer.get_dispatch_layout(topk_idx, num_experts,
previous_event=previous_event, async_finish=True,
allocate_on_comm_stream=previous_event is not None)
- 功能:通过
_buffer.get_dispatch_layout预计算 token 分发的关键元数据,为实际通信做准备。 - 返回值解析:
num_tokens_per_rank:每个进程(rank)需要处理的 token 数量。num_tokens_per_rdma_rank:跨节点通信(RDMA)中每个进程的 token 数量。num_tokens_per_expert:每个专家需要处理的 token 数量。is_token_in_rank:布尔张量,标记每个 token 是否属于当前进程。previous_event:更新后的 CUDA 事件(若传入previous_event,则等待其完成后再继续)。
步骤 3:执行实际调度(Dispatch)
# 执行 MoE 调度,将 token 分发到对应专家/进程
recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event = \
_buffer.dispatch(x, topk_idx=topk_idx, topk_weights=topk_weights,
num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank, num_tokens_per_expert=num_tokens_per_expert,
previous_event=previous_event, async_finish=True,
allocate_on_comm_stream=True)
- 核心操作:调用
_buffer.dispatch完成实际的跨进程数据传输,将输入x、topk_idx、topk_weights分发到对应专家所在的进程。 - 关键参数:
async_finish=True:启用异步调度,通信操作不阻塞 CPU(优化吞吐量)。allocate_on_comm_stream=True:在通信流(而非计算流)上分配内存,避免占用计算资源。- 布局元数据(
num_tokens_per_rank等):指导如何高效分发数据。
步骤 4:返回调度结果
return recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, event
- 返回值解析:
recv_x:当前进程接收到的、需要由本地专家处理的 token 数据(格式与输入x一致,如张量或 FP8 元组)。recv_topk_idx/recv_topk_weights:分发后当前进程本地专家对应的专家索引和权重。num_recv_tokens_per_expert_list:列表,每个元素为当前进程中每个本地专家接收到的 token 数量。handle:调度句柄(元组),包含后续聚合(combine)所需的元数据(如 token 前缀矩阵)。event:EventOverlap对象,用于跟踪调度操作的 CUDA 事件(便于后续同步或重叠)。
4. 关键特性与注意事项
- 通信 - 计算重叠:通过
previous_event可指定依赖的 CUDA 事件(如等待前一个计算 kernel 完成),实现通信与计算的并行。 - 异步执行:
async_finish=True使调度操作异步执行,CPU 无需等待 GPU 通信完成即可继续后续逻辑(需通过event显式同步)。 - CUDA 图兼容性:注释提到“CPU 会等待 GPU 信号”,因此默认不兼容 CUDA 图(除非指定
num_worst_tokens,但仅限节点内场景)。
总结
dispatch_forward 是 MoE 前向传播的核心函数,通过以下流程实现 token 到专家的高效分发:
- 布局计算:预计算 token 分布 metadata(进程/专家的 token 数量)。
- 实际调度:基于 metadata 调用缓冲区通信接口,跨进程传输数据。
- 返回结果:提供本地专家所需的输入数据、元数据及事件句柄,为后续专家计算和结果聚合(
combine_forward)做准备。
其设计优化了分布式通信效率,支持异步/重叠操作,是 DeepEP 库实现高效专家并行的关键组件。
4. dispatch_backward
其核心作用是实现 MoE(Mixture of Experts)模型中“调度前向”(dispatch_forward)的反向传播逻辑,通过聚合各专家返回的梯度,生成原始输入的梯度。以下是详细解析:
- 作用:执行 MoE 分发操作的反向过程,实际上是一个合并操作。调用
_buffer.combine方法,将梯度数据和 Top-K 权重梯度进行合并,返回合并后的梯度数据、合并后的 Top-K 权重梯度和事件对象。
1. 函数定义与核心功能
def dispatch_backward(grad_recv_x: torch.Tensor, grad_recv_topk_weights: torch.Tensor, handle: Tuple) -> \
Tuple[torch.Tensor, torch.Tensor, EventOverlap]:
- 功能:在 MoE 反向传播中,将专家处理后的梯度(
grad_recv_x、grad_recv_topk_weights)聚合回原进程,生成与dispatch_forward输入对应的梯度(如原始x和topk_weights的梯度)。 - 本质:前向传播的“调度”(
dispatch_forward)将 token 分发到专家,反向传播则需将专家的梯度“聚合”回原进程,因此调度的反向过程本质是聚合操作(见代码注释:“The backward process of MoE dispatch is actually a combine”)。
2. 参数解析
| 参数名 | 类型 | 说明 |
|-------------------------|-----------------------|----------------------------------------------------------------------|
| grad_recv_x | torch.Tensor | 专家处理后返回的隐藏层梯度(形状与 dispatch_forward 返回的 recv_x 一致)。|
| grad_recv_topk_weights| torch.Tensor | 专家处理后返回的 topk_weights 梯度(形状与 dispatch_forward 返回的 recv_topk_weights 一致)。|
| handle | Tuple | 调度句柄,由 dispatch_forward 返回,包含聚合所需的元数据(如 token 路由信息、进程前缀矩阵等)。|
3. 关键逻辑步骤
步骤 1:声明全局通信缓冲区
global _buffer
_buffer是全局deep_ep.Buffer实例,负责管理分布式通信资源(与dispatch_forward共用,避免重复初始化)。
步骤 2:调用 combine 聚合梯度
combined_grad_x, combined_grad_recv_topk_weights, event = \
_buffer.combine(grad_recv_x, handle, topk_weights=grad_recv_topk_weights, async_finish=True)
- 核心操作:通过
_buffer.combine聚合梯度,具体逻辑为:
- 输入梯度:
grad_recv_x(隐藏层梯度)和grad_recv_topk_weights(权重梯度)是专家处理后返回的局部梯度。 - 句柄依赖:
handle提供聚合所需的路由元数据(如哪些 token 来自哪个进程),确保梯度按前向调度的逆过程聚合。 - 异步执行:
async_finish=True使聚合操作异步执行(CPU 无需等待 GPU 通信完成,通过event跟踪状态)。
步骤 3:返回聚合结果
return combined_grad_x, combined_grad_recv_topk_weights, event
- 返回值解析:
combined_grad_x:聚合后的隐藏层梯度,对应dispatch_forward输入x的梯度。combined_grad_recv_topk_weights:聚合后的权重梯度,对应dispatch_forward输入topk_weights的梯度。event:EventOverlap对象,用于跟踪聚合操作的 CUDA 事件(需显式同步时使用,如event.current_stream_wait())。
4. 核心设计思想:前向 - 反向操作对称
MoE 中“调度”与“聚合”是对称操作:
- 前向:
dispatch_forward将 token 从“原进程”分发到“专家进程”(多对多通信)。 - 反向:
dispatch_backward通过combine将梯度从“专家进程”聚合回“原进程”(多对多通信的逆过程)。
这种设计避免了重复开发反向传播逻辑,直接复用 combine 接口实现梯度聚合,简化了代码并确保通信逻辑一致性。
总结
dispatch_backward 是 MoE 反向传播的关键函数,通过以下方式实现梯度聚合:
- 利用“调度前向的反向是聚合”的对称特性,复用
_buffer.combine接口。 - 基于前向调度生成的
handle元数据,确保梯度按正确路由聚合。 - 支持异步执行(
async_finish=True),提升反向传播效率。
最终返回与 dispatch_forward 输入对应的梯度(combined_grad_x、combined_grad_recv_topk_weights),完成 MoE 反向传播的梯度链。
5. combine_forward
其核心作用是在 MoE(Mixture of Experts)模型前向传播中执行“结果聚合”逻辑,即将专家处理后的 token 结果(如隐藏层输出)聚合回原进程,生成与输入 dispatch_forward 对应的完整输出。以下是详细解析:
- 作用:执行 MoE 合并操作。调用
_buffer.combine方法,将输入数据进行合并,返回合并后的数据和事件对象。
1. 函数定义与核心功能
def combine_forward(x: torch.Tensor, handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
Tuple[torch.Tensor, EventOverlap]:
- 功能:实现 MoE 前向传播中的“专家结果聚合”,将各专家处理后的 token 数据(
x)根据路由信息(handle)聚合回原进程,生成与原始输入 token 数量匹配的输出张量。 - 应用场景:紧跟在
dispatch_forward和专家计算之后,是 MoE 前向传播的收尾步骤(例如:将 8 个专家的输出聚合为完整的 batch 结果)。
2. 参数解析
| 参数名 | 类型 | 说明 |
|----------------------|---------------------------------------|----------------------------------------------------------------------|
| x | torch.Tensor | 专家处理后的隐藏层输出(形状与 dispatch_forward 返回的 recv_x 一致,即当前进程本地专家处理的 token 结果)。|
| handle | Tuple | 聚合句柄,由 dispatch_forward 返回,包含 token 路由元数据(如原进程索引、token 在原进程中的位置等),是聚合的“地图”。|
| previous_event | Optional[EventOverlap] | 可选的 CUDA 事件依赖,用于通信 - 计算重叠(如等待专家计算 kernel 完成后再启动聚合通信)。|
3. 关键逻辑步骤
步骤 1:声明全局通信缓冲区
global _buffer
_buffer是全局deep_ep.Buffer实例,提供分布式通信能力(与dispatch_forward共用,确保通信资源统一管理)。
步骤 2:调用 combine 执行聚合
# 执行 MoE 聚合,将专家输出合并回原进程
combined_x, _, event = _buffer.combine(x, handle, async_finish=True, previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None)
- 核心操作:通过
_buffer.combine完成跨进程聚合,具体逻辑为:
- 输入数据:
x是当前进程本地专家处理后的 token 结果(如形状(local_tokens, hidden_dim))。 - 路由依赖:
handle提供聚合所需的元数据(如每个 token 来自哪个原进程、在原进程中的位置索引),确保结果按dispatch_forward的逆过程准确聚合。 - 异步与性能优化:
async_finish=True:聚合操作异步执行,CPU 无需等待 GPU 通信完成即可继续后续逻辑(通过event跟踪状态)。allocate_on_comm_stream=previous_event is not None:若存在previous_event(通信 - 计算重叠场景),则在通信流(而非计算流)上分配内存,避免占用计算资源。
步骤 3:返回聚合结果
return combined_x, event
- 返回值解析:
combined_x:聚合后的完整输出张量,形状与dispatch_forward的输入x一致(如(num_tokens, hidden_dim)),每个 token 对应其选中专家的加权结果。event:EventOverlap对象,用于跟踪聚合操作的 CUDA 事件(需显式同步时调用event.current_stream_wait(),确保 GPU 完成聚合后再使用combined_x)。
4. 与 dispatch_forward 的关系
MoE 前向传播是“调度 - 计算 - 聚合”的完整流程:
dispatch_forward:将输入 token 按topk_idx分发到对应专家(“分”)。- 专家计算:各专家在本地进程中处理分配的 token(“算”)。
combine_forward:将专家处理结果按原路由聚合回原进程(“合”)。
二者通过 handle 关联(dispatch_forward 返回 handle,combine_forward 接收 handle),确保 token 分发与聚合的路由一致性。
总结
combine_forward 是 MoE 前向传播的“最后一公里”,通过以下方式实现专家结果聚合:
- 基于
dispatch_forward生成的handle路由元数据,确保聚合的准确性。 - 调用
_buffer.combine完成跨进程通信,支持异步执行和通信 - 计算重叠。 - 返回聚合后的完整输出
combined_x和跟踪事件event,为后续模型层提供输入。
该函数是专家并行(MoE)中“数据分发 - 处理 - 聚合”闭环的核心组件,直接影响模型输出的正确性和训练/推理效率。
6. combine_backward
其核心作用是实现 MoE(Mixture of Experts)模型中“聚合前向”(combine_forward)的反向传播逻辑,通过将聚合结果的梯度(grad_combined_x)分发回原专家进程,生成与 combine_forward 输入对应的梯度。以下是详细解析:
- 作用:执行 MoE 合并操作的反向过程,实际上是一个分发操作。调用
_buffer.dispatch方法,将合并后的梯度数据进行分发,返回分发后的梯度数据和事件对象。
1. 函数定义与核心功能
def combine_backward(grad_combined_x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
handle: Tuple, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], EventOverlap]:
- 功能:在 MoE 反向传播中,将聚合后的梯度(
grad_combined_x)分发回原专家进程,生成与combine_forward输入x对应的梯度(即专家处理前的输入梯度)。 - 本质:前向传播的“聚合”(
combine_forward)将专家结果聚合到原进程,反向传播则需将聚合梯度“分发”回专家进程,因此聚合的反向过程本质是调度操作(见代码注释:“The backward process of MoE combine is actually a dispatch”)。
2. 参数解析
| 参数名 | 类型 | 说明 |
|----------------------|---------------------------------------|----------------------------------------------------------------------|
| grad_combined_x | 张量或张量元组 | 聚合结果 combined_x 的梯度(形状与 combine_forward 返回的 combined_x 一致,如 (num_tokens, hidden_dim))。|
| handle | Tuple | 聚合句柄,由 combine_forward 关联的 dispatch_forward 返回,包含梯度分发所需的元数据(如专家进程索引、token 路由信息等)。|
| previous_event | Optional[EventOverlap] | 可选的 CUDA 事件依赖,用于通信 - 计算重叠(如等待前一个梯度计算 kernel 完成后再启动分发通信)。|
3. 关键逻辑步骤
步骤 1:声明全局通信缓冲区
global _buffer
_buffer是全局deep_ep.Buffer实例,提供分布式通信能力(与combine_forward、dispatch_forward共用,确保通信资源统一管理)。
步骤 2:调用 dispatch 分发梯度
# 聚合的反向过程本质是调度:将梯度分发回原专家进程
grad_x, _, _, _, _, event = _buffer.dispatch(grad_combined_x, handle=handle, async_finish=True,
previous_event=previous_event,
allocate_on_comm_stream=previous_event is not None)
- 核心操作:通过
_buffer.dispatch分发梯度,具体逻辑为:
- 输入梯度:
grad_combined_x是聚合结果的梯度(如模型最终输出对combine_forward输出combined_x的梯度)。 - 句柄依赖:
handle提供梯度分发所需的路由元数据(如哪些专家参与了聚合、每个专家对应的 token 范围),确保梯度按combine_forward的逆过程准确分发到原专家进程。 - 性能优化:
async_finish=True:梯度分发异步执行,CPU 无需等待 GPU 通信完成即可继续后续逻辑(通过event跟踪状态)。allocate_on_comm_stream=previous_event is not None:若存在previous_event(通信 - 计算重叠场景),则在通信流(而非计算流)上分配内存,避免占用计算资源。
步骤 3:返回梯度与事件
return grad_x, event
- 返回值解析:
grad_x:分发后的梯度,对应combine_forward输入x的梯度(即专家处理前的输入梯度,需传递给专家的反向传播逻辑)。event:EventOverlap对象,用于跟踪梯度分发操作的 CUDA 事件(需显式同步时调用event.current_stream_wait(),确保 GPU 完成分发后再使用grad_x)。
4. 与 MoE 整体流程的关系
MoE 前向 - 反向传播是“调度 - 聚合 - 聚合反向 - 调度反向”的闭环:
- 前向:
dispatch_forward(分发 token 到专家)→ 专家计算 →combine_forward(聚合专家结果)。 - 反向:
combine_backward(本函数,分发聚合梯度到专家)→ 专家反向计算 →dispatch_backward(聚合专家梯度到原进程)。
combine_backward 是连接聚合结果梯度与专家梯度的关键环节,确保梯度能从最终输出反向传播到各专家,实现 MoE 模型的端到端训练。
总结
combine_backward 函数通过以下方式实现 MoE 聚合操作的反向传播:
- 利用对称特性:基于“聚合的反向是调度”的逻辑,复用
_buffer.dispatch接口分发梯度。 - 依赖路由元数据:通过
handle确保梯度按前向聚合的逆路径准确分发到原专家进程。 - 异步与性能优化:支持异步执行(
async_finish=True)和通信流内存分配,提升反向传播效率。
最终返回专家输入对应的梯度 grad_x 和跟踪事件 event,为专家层的反向传播提供输入。