[(5 条消息) 图解大模型训练系列之:DeepSpeed-Megatron MoE 并行训练(原理篇) - 知乎](图解大模型训练系列之:DeepSpeed-Megatron MoE并行训练(原理篇))
Pytest 设置
test_token_dispatcher.py 中 TestFlexDispatcher
1. TestFlexDispatcher 测试类结构
TestFlexDispatcher 是针对 Flex 令牌调度器 的测试类,使用 pytest 框架组织测试用例。核心测试方法包括:
-
test_forward_backward:验证前向/反向传播正确性 -
test_capacity_forward_backward:验证专家容量控制逻辑 -
test_router_padding_for_fp8_forward_backward:验证 FP8 训练时的路由填充逻辑
2. pytest 核心机制应用
(1)测试前置/后置处理
class TestFlexDispatcher:
def setup_method(self, method):
pass # 测试方法执行前的初始化(当前为空)
def teardown_method(self, method):
Utils.destroy_model_parallel() # 测试后清理模型并行状态
-
setup_method:每个测试方法执行前调用(此处未使用,可用于初始化共享资源) -
teardown_method:每个测试方法执行后调用,确保模型并行状态正确销毁,避免测试污染
(2)条件跳过测试
@pytest.mark.skipif(not is_deep_ep_available(), reason="Deep EP is not available")
class TestFlexDispatcher: ...
- 使用
@pytest.mark.skipif标记类级别跳过条件:当 Deep EP 不可用时,整个TestFlexDispatcher类的测试会被跳过
(3)参数化测试(核心)
通过 @pytest.mark.parametrize 实现多场景测试,自动生成测试用例组合:
@pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)])
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # [False, True](取决于TE版本)
@pytest.mark.parametrize("experimental_fusion", [True, False])
def test_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion):
...
-
参数组合逻辑:3(tp_size/ep_size)× 2(permute_fusion)× 2(experimental_fusion)= 12 个测试用例
-
动态参数来源:
permute_fusion_params根据 Transformer Engine 版本动态生成([False]或[False, True])
3. 测试流程与输入输出验证
以 test_forward_backward 为例,完整测试流程如下:
(1)测试环境准备
container = MoEModelTestContainer(
tp_size=tp_size,
ep_size=ep_size,
pp_size=1,
num_moe_experts=8,
moe_token_dispatcher_type="flex", # 指定 Flex 调度器
moe_permute_fusion=permute_fusion,
experimental_fusion=experimental_fusion, # 控制实验性融合开关
...
)
-
MoEModelTestContainer是测试容器类,封装了 MoE 层初始化、并行环境配置等逻辑 -
通过构造参数控制测试场景(如张量并行大小
tp_size、专家并行大小ep_size等)
(2)执行测试用例
container.dispatcher_dropless_test() # 调用具体测试逻辑
dispatcher_dropless_test 方法实现核心验证逻辑:
-
输入:随机生成的
hidden_states(形状[bs, seql, hidden_size]) -
处理流程:
-
路由计算:
probs, indices = moe_layer.router(hidden_states) -
令牌调度:
token_permutation(分发→排列→后处理) -
令牌反调度:
token_unpermutation(合并→恢复)
- 输出验证:
# 验证恢复的 hidden_states 与原始输入一致
assert torch.allclose(restored_hidden_states, ans)
# 验证反向传播梯度正确性
torch.autograd.backward(restored_hidden_states, hidden_states)
assert torch.allclose(hidden_states.grad, ans)
4. 关键测试场景说明
| 测试方法 | 核心验证目标 | 输入特点 | 输出验证指标 |
| ---------------------------------- | ------------------------------ | ------------------- | -------------------- |
| test_forward_backward | Flex 调度器基本前向/反向传播正确性 | 随机 hidden_states | 恢复的特征、梯度与原始输入一致 |
| test_capacity_forward_backward | 专家容量控制(令牌数量不超过专家容量) | 固定 num_tokens=16 | 令牌数 ≤ 专家容量,输出特征正确性 |
| test_router_padding_for_fp8_test | FP8 训练时路由映射填充逻辑(令牌数对齐到 16 的倍数) | num_tokens=32 | 填充前后输出一致,令牌数为 16 的倍数 |
5. pytest 执行与结果呈现
-
用例发现:pytest 自动识别
TestFlexDispatcher类中以test_开头的方法 -
测试报告:每个参数组合生成独立的测试结果(如
test_forward_backward[8-1-True-True]) -
失败处理:若
assert断言失败(如输出特征不一致),pytest 会显示详细的张量差异信息
总结
该文件通过 pytest 的 参数化测试 实现了对 Flex 令牌调度器多场景覆盖,结合 测试容器类 封装环境配置,通过 输入随机化 和 输出断言 验证核心功能正确性。这种设计确保了测试的可扩展性(新增参数组合即可覆盖新场景)和可靠性(自动清理环境 + 独立测试用例)。
使用 Pytest 测试 TestFlexDispatcher
一、单独测试 TestFlexDispatcher 的命令
TestFlexDispatcher 是针对 Flex 令牌调度器 的测试类,使用 pytest 单独运行其测试的命令如下:
pytest tests/unit_tests/transformer/moe/test_token_dispatcher.py -k "TestFlexDispatcher" -v
参数说明:
-
tests/unit_tests/transformer/moe/test_token_dispatcher.py:指定测试文件路径。 -
-k "TestFlexDispatcher":通过关键字匹配,仅运行类名包含TestFlexDispatcher的测试(过滤其他测试类,如TestAllgatherDispatcher)。 -
-v:显示详细测试日志(每个子用例的执行结果)。
二、pytest 核心语法在文件中的应用
1. 测试类定义与生命周期管理
@pytest.mark.skipif(not is_deep_ep_available(), reason="Deep EP is not available")
class TestFlexDispatcher:
def setup_method(self, method):
pass # 测试方法执行前的初始化(当前为空,可用于创建共享资源)
def teardown_method(self, method):
Utils.destroy_model_parallel() # 测试后清理模型并行状态,避免影响其他测试
@pytest.mark.skipif(类级别跳过):
条件判断是否跳过整个类。例如,若 Deep EP 库不可用(not is_deep_ep_available()),则 TestFlexDispatcher 所有测试方法均不执行。
-
setup_method/teardown_method: -
setup_method:每个测试方法(如test_forward_backward)执行前调用,用于初始化测试环境(如创建临时文件、初始化模型)。 -
teardown_method:每个测试方法执行后调用,用于清理资源(如销毁分布式环境、释放 GPU 内存)。
2. 测试方法与参数化(核心功能)
以 test_capacity_forward_backward 为例, pytest 通过 参数化 实现多场景自动测试:
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.internal # 自定义标记(可能用于分类测试,如内部测试/公开测试)
@pytest.mark.timeout(120) # 超时控制:测试超过 120 秒则失败
@pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) # 参数组合1
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # 参数组合2(动态生成)
@pytest.mark.parametrize("experimental_fusion", [True, False]) # 参数组合3
def test_capacity_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion):
# 测试逻辑:初始化容器 → 执行测试 → 验证结果
container = MoEModelTestContainer(tp_size=tp_size, ep_size=ep_size, ...)
container.dispatcher_capacity_test() # 核心测试逻辑
关键语法解析
@pytest.mark.parametrize(参数化测试):
自动生成多组测试用例,每组参数对应一个独立的测试。例如上述代码中:
-
tp_size,ep_size有 3 种组合:(1,8)、(8,1)、(4,2)。 -
permute_fusion由permute_fusion_params动态生成(如[False, True],取决于 Transformer Engine 版本)。 -
experimental_fusion有 2 种组合:True/False。
→ 总测试用例数:3(tp/ep) × 2(permute) × 2(experimental) = 12 个,pytest 会自动运行所有组合并报告每个用例的结果。
@pytest.mark.timeout(120):
设置单个测试方法的超时时间(120 秒),防止测试因死锁或性能问题无限期阻塞。
@pytest.mark.skipif(方法级别跳过):
比类级别更细粒度的跳过条件。例如,若没有 CUDA 环境(not torch.cuda.is_available()),则跳过该测试方法。
3. 测试逻辑与输入输出验证
测试方法的核心流程是:构造输入 → 执行待测试逻辑 → 验证输出。以 dispatcher_capacity_test 为例(被 test_capacity_forward_backward 调用):
输入构造
# 生成随机输入张量(模拟模型中间层输出)
hidden_states = torch.randn((num_tokens, hidden_size)).cuda()
hidden_states.requires_grad = True # 启用梯度计算,用于反向传播测试
核心逻辑执行
# 路由计算(专家选择概率和索引)
probs, indices = moe_layer.router(hidden_states)
# 令牌调度(分发到专家)
permuted_input, tokens_per_expert, permuted_probs = token_permutation(...)
# 令牌合并(从专家收集结果)
restored_hidden_states = token_unpermutation(...)
输出验证(断言)
# 验证专家接收的令牌数不超过容量上限
assert torch.all(tokens_per_expert <= capacity), "令牌数超过专家容量"
# 验证反向传播梯度正确性(张量接近度比较,允许微小误差)
assert torch.allclose(hidden_states.grad, expected_grad), "梯度计算错误"
torch.allclose:由于浮点数计算误差,不直接用==比较张量,而是检查是否在允许的误差范围内(如rtol=1e-5,atol=1e-8)。
三、常用 Pytest 命令扩展
| 命令示例 | 作用 |
|----------|------|
| pytest -k TestFlexDispatcher -v | 仅运行 TestFlexDispatcher 类,显示详细日志 |
| pytest -s | 显示测试中的 print 输出(默认 pytest 会捕获输出) |
| pytest --lf | 仅运行上次失败的测试(快速复现问题) |
| pytest --cov=megatron.core.transformer.moe | 生成测试覆盖率报告(需安装 pytest-cov) |
总结
TestFlexDispatcher 通过 pytest 的 参数化、条件跳过、生命周期管理 和 断言验证,实现了对 Flex 令牌调度器多场景(不同并行策略、融合开关、容量配置)的自动化测试。核心步骤是:
-
用
@pytest.mark.parametrize定义测试参数组合; -
用
MoEModelTestContainer构造测试环境; -
调用具体测试方法(如
dispatcher_capacity_test)执行逻辑; -
用
assert+torch.allclose验证输出和梯度正确性。
TestFlexDispatcher 中三种测试方法的 MoE 场景与参数设计解析
一、test_forward_backward:基础前向/反向传播功能测试
测试目标:验证 Flex 令牌调度器在 无令牌丢弃(dropless)场景 下的核心功能正确性,包括令牌分发、合并及梯度传播。
核心验证逻辑
通过 dispatcher_dropless_test 实现:
-
生成随机输入
hidden_states,经 MoE 路由(router)计算专家选择概率(probs)和索引(indices)。 -
调用
token_permutation(令牌分发)和token_unpermutation(令牌合并),验证合并后的restored_hidden_states与原始输入一致。 -
验证反向传播梯度(
hidden_states.grad)与原始输入的一致性,确保梯度计算未被调度逻辑破坏。
参数设计原因
@pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)]) # 并行策略组合
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # 融合开关
@pytest.mark.parametrize("experimental_fusion", [True, False]) # 实验性融合开关
def test_forward_backward(...): ...
-
tp_size/ep_size: 测试不同并行策略组合(张量并行/专家并行),确保调度器在模型并行拆分时仍能正确路由令牌。例如: -
(8, 1):8 路张量并行,专家不并行(所有专家在单个 GPU)。 -
(1, 8):专家并行(8 个专家分布在 8 个 GPU),无张量并行。 -
(2, 4):混合并行(2 路张量并行 + 4 路专家并行)。 -
permute_fusion/experimental_fusion: 测试调度器与融合优化(如算子融合、内存优化)的兼容性,确保功能正确性不受性能优化影响。
二、test_capacity_forward_backward:专家容量控制测试
测试目标:验证 Flex 调度器在 专家容量限制场景 下的令牌丢弃/截断逻辑,确保专家接收的令牌数不超过其容量上限。
核心验证逻辑
通过 dispatcher_capacity_test 实现:
-
计算专家容量(
capacity):基于moe_expert_capacity_factor(容量系数)、令牌总数和专家数量。 -
验证
tokens_per_expert(每个专家实际接收的令牌数)是否小于等于capacity * ep_size * tp_size(考虑并行拆分后的实际容量)。 -
验证令牌合并后的输出和梯度与预期结果一致,确保容量控制未破坏计算正确性。
参数设计原因
container = MoEModelTestContainer(
moe_expert_capacity_factor=0.5, # 容量系数(控制专家最大令牌数)
moe_token_drop_policy="probs", # 令牌丢弃策略(按概率丢弃超额令牌)
hidden_size=4, # 小隐藏维度加速测试
...
)
-
moe_expert_capacity_factor=0.5: 将专家容量设为默认值的 50%,刻意降低容量以触发令牌丢弃逻辑,验证调度器是否能正确截断超额令牌。 -
moe_token_drop_policy="probs": 测试按概率丢弃令牌的策略(而非随机丢弃),确保高概率令牌优先保留。 -
hidden_size=4: 使用小维度输入加速测试,减少计算资源消耗。
三、test_router_padding_for_fp8_test:FP8 训练场景下的路由填充测试
测试目标:验证 Flex 调度器在 FP8 混合精度训练 时的令牌填充逻辑,确保专家输入长度对齐到 16 的倍数(FP8 kernels 要求),且填充不影响计算结果。
核心验证逻辑
通过 dispatcher_router_padding_for_fp8_test 实现:
-
对比两种模式:
moe_router_padding_for_fp8=False(无填充)和True(填充到 16 的倍数)。 -
验证填充模式下
tokens_per_expert为 16 的倍数(tokens_per_expert_2 % 16 == 0)。 -
验证两种模式的输出
restored_hidden_states和梯度完全一致,确保填充未引入计算误差。
参数设计原因
container = MoEModelTestContainer(
num_moe_experts=32, # 更多专家,增加令牌分布随机性
moe_router_topk=4, # 更高 top-k,增加每个令牌选择的专家数
fp8="hybrid", # 启用 FP8 混合精度
...
)
-
num_moe_experts=32/moe_router_topk=4: 增加专家数量和每个令牌选择的专家数,使令牌分布更分散,更容易出现非 16 倍数的令牌数,从而测试填充逻辑的鲁棒性。 -
fp8="hybrid": 显式启用 FP8 模式,确保测试环境与实际 FP8 训练一致。 -
@pytest.mark.skipif(not is_te_min_version("1.7.0")): FP8 路由填充依赖 Transformer Engine 1.7.0+ 的特性,低版本 TE 不支持,故跳过。
总结:三种测试的互补性
| 测试方法 | 核心场景 | 关键参数设计逻辑 |
|-----------------------------------|---------------------------|--------------------------------------------------------------------------------|
| test_forward_backward | 基础功能验证(无丢弃) | 多并行策略组合,验证调度器在不同模型拆分下的正确性。 |
| test_capacity_forward_backward | 专家容量控制(有丢弃) | 低容量系数触发丢弃逻辑,验证令牌截断和专家负载均衡。 |
| test_router_padding_for_fp8_test| FP8 训练兼容性(需填充) | 更多专家/更高 top-k 增加填充必要性,验证 FP8 kernels 所需的输入对齐要求。 |
三者共同覆盖了 MoE 调度器的 基础功能、容量控制、精度兼容性 三大核心场景,参数设计均针对场景痛点(如并行拆分、容量超限、FP8 对齐),确保测试的全面性和有效性。
基础测试流程
TestFlexDispatcher.test_forward_backward 调用流程与模型配置解析
1. 测试方法定义与参数化(pytest 驱动)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # 条件跳过:无CUDA时不执行
@pytest.mark.internal # 内部测试标记(用于分类测试)
@pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)]) # 参数组合1:张量并行/专家并行策略
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # 参数组合2:算子融合开关(True/False)
@pytest.mark.parametrize("experimental_fusion", [True, False]) # 参数组合3:实验性融合开关
def test_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion): # 参数化输入
作用:
-
@pytest.mark.parametrize生成多组测试用例(如tp_size=8,ep_size=1或tp_size=2,ep_size=4),覆盖不同并行策略和融合配置。 -
tp_size/ep_size:控制模型并行拆分方式(张量并行/专家并行),验证 Flex 调度器在不同并行场景下的兼容性。
2. 实验性融合开关配置
if experimental_fusion:
config.ENABLE_EXPERIMENTAL = True # 启用实验性融合(如内存优化或算子合并)
作用:
- 动态控制全局配置
config.ENABLE_EXPERIMENTAL,测试调度器与实验性优化的兼容性。
3. 创建 MoE 测试容器(核心配置入口)
container = MoEModelTestContainer(
tp_size=tp_size, # 张量并行大小(来自参数化输入)
ep_size=ep_size, # 专家并行大小(来自参数化输入)
pp_size=1, # 流水线并行大小(固定为1,无需测试)
num_moe_experts=8, # 总专家数:8个专家
moe_router_topk=2, # 路由选择Top-2专家(每个令牌分配给2个专家)
moe_router_load_balancing_type="aux_loss", # 负载均衡策略:辅助损失函数
moe_token_dispatcher_type="flex", # 令牌调度器类型:Flex(测试目标)
moe_permute_fusion=permute_fusion, # 排列融合开关(来自参数化输入)
hidden_size=4, # 隐藏层维度(小维度加速测试)
moe_enable_deepep=True, # 启用Deep EP优化(Flex调度器依赖)
)
关键配置解析:
-
并行策略:
tp_size(张量并行)和ep_size(专家并行)决定模型拆分方式。例如,tp_size=8,ep_size=1表示 8 路张量并行、专家不并行;ep_size=8表示 8 个专家分布在 8 个 GPU。 -
MoE 核心参数:
-
num_moe_experts=8:总专家数为 8,与ep_size共同决定每个 GPU 的本地专家数(num_local_experts = 8 // ep_size)。 -
moe_router_topk=2:每个令牌路由到 2 个专家,验证多专家分配场景。 -
moe_token_dispatcher_type="flex":显式指定使用 Flex 调度器(测试核心目标)。
4. MoEModelTestContainer 初始化流程(模型环境搭建)
MoEModelTestContainer 的 __init__ 方法完成以下关键步骤:
(1)模型并行环境初始化
Utils.initialize_model_parallel(
tensor_model_parallel_size=tp_size,
expert_model_parallel_size=ep_size,
... # 其他并行参数
)
作用:通过 Utils 工具类设置分布式环境,包括张量并行组、专家并行组的初始化,确保各 GPU 间通信正常。
(2)生成 TransformerConfig(模型配置核心)
self.config = TransformerConfig(
tensor_model_parallel_size=tp_size,
expert_model_parallel_size=ep_size,
moe_token_dispatcher_type="flex", # 绑定Flex调度器
hidden_size=4, # 小隐藏维度(减少计算量)
... # 其他MoE配置(如容量系数、融合开关)
)
作用:创建 TransformerConfig 对象,集中管理模型架构、并行策略、MoE 特性等配置,后续传递给 MoELayer。
(3)初始化 MoE 层
self.moe_layer = self.new_moe_layer() # 调用new_moe_layer创建MoE层
new_moe_layer 方法细节:
def new_moe_layer(self, **kargs):
transformer_layer_spec = get_gpt_layer_local_spec(...) # 获取GPT层规格(含MLP子模块)
new_config = dataclasses.replace(self.config, **kargs) # 合并配置(支持动态修改)
moe_layer = MoELayer(new_config, transformer_layer_spec.submodules.mlp.submodules).cuda() # 创建MoE层并移至GPU
return moe_layer
作用:
-
MoELayer是核心 MoE 组件,包含路由模块(router)和 Flex 令牌调度器(token_dispatcher)。 -
transformer_layer_spec指定 MLP 子模块结构,确保专家网络符合 GPT 架构。
5. 执行测试逻辑:dispatcher_dropless_test
container.dispatcher_dropless_test() # 调用无令牌丢弃场景的测试
该方法是实际测试逻辑的入口,内部流程如下:
(1)生成测试输入
bs = 32 # batch size
seql = 8 # 序列长度
hidden_states = torch.randn((bs, seql, moe_layer.config.hidden_size)).cuda() # 随机输入张量 [32,8,4]
hidden_states.requires_grad = True # 启用梯度计算(用于反向传播测试)
(2)路由计算(专家选择)
probs, indices = moe_layer.router(hidden_states) # 路由输出:专家选择概率(probs)和索引(indices)
probs = torch.ones_like(probs) / moe_layer.router.topk # 归一化概率(等概率选择Top-k专家)
(3)Flex 调度器令牌分发与合并
# 令牌分发:将hidden_states按indices分配给对应专家
permuted_local_hidden_states, tokens_per_expert, permuted_probs = token_permutation(
moe_layer.token_dispatcher, hidden_states, probs, indices
)
# 令牌合并:从专家收集结果并恢复原始顺序
restored_hidden_states, restored_bias = token_unpermutation(
moe_layer.token_dispatcher, permuted_local_hidden_states
)
(4)结果验证(断言)
# 验证前向传播:恢复的hidden_states与原始输入一致
assert torch.allclose(restored_hidden_states, ans), "前向传播结果不匹配"
# 验证反向传播:梯度与原始输入一致
torch.autograd.backward(restored_hidden_states, hidden_states) # 反向传播
assert torch.allclose(hidden_states.grad, ans), "反向传播梯度不匹配"
6. 清理实验性配置
config.ENABLE_EXPERIMENTAL = False # 重置实验性开关(避免影响其他测试)
总结:调用流程与配置关系
-
pytest 参数驱动:通过
@pytest.mark.parametrize生成多组tp_size/ep_size/fusion组合,覆盖不同并行和优化场景。 -
容器初始化:
MoEModelTestContainer封装模型并行环境、TransformerConfig和MoELayer,统一管理测试依赖。 -
Flex 调度器测试:核心通过
dispatcher_dropless_test验证令牌分发/合并的前向/反向正确性,确保 Flex 调度器在各种配置下功能正常。
整个流程围绕 Flex 令牌调度器 的核心功能展开,通过参数化和断言验证,确保其在不同并行策略、融合优化下的鲁棒性。
test_moe_deepep.py 中 token_permutation 函数详解
函数定义
def token_permutation(token_dispatcher, hidden_states, probs, indices):
hidden_states, probs = token_dispatcher.dispatch_preprocess(hidden_states, indices, probs)
hidden_states, probs = token_dispatcher.token_dispatch(hidden_states, probs)
hidden_states, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess(
hidden_states, probs
)
return hidden_states, tokens_per_expert, permuted_probs
输入参数
token_dispatcher:
-
类型:
MoEFlexTokenDispatcher实例 -
作用: MoE 令牌分发器,负责管理令牌到专家的路由和通信
hidden_states:
-
形状:
[S, B, H]或[S*B, H] -
含义:
-
S: 序列长度 -
B: 批次大小 -
H: 隐藏层维度 -
在
dispatch_preprocess中会被展平为[-1, H],即[S*B, H]
probs:
-
形状:
[S, B, E]或[S*B, E] -
含义:
-
E: 专家数量 -
表示每个令牌路由到各个专家的概率
-
在
dispatch_preprocess中会被展平为[-1, E],即[S*B, E]
indices:
-
形状:
[S, B, K]或[S*B, K] -
含义:
-
K: 每个令牌选择的专家数量(top-K) -
表示每个令牌应该路由到的专家索引
-
在
dispatch_preprocess中会被展平为[-1, K],即[S*B, K]
处理步骤
dispatch_preprocess:
-
输入:
-
hidden_states:[S*B, H] -
indices:[S*B, K] -
probs:[S*B, K](注意:这里 probs 通常是从 indices 和原始 probs 计算得出的) -
输出:
-
hidden_states:[S*B, H] -
probs:[S*B, K] -
主要操作:
-
将输入张量展平为二维
-
调用
_initialize_metadata将路由图和概率扩展到 TPxEP 组,形状变为[S*B, world_size, num_local_experts] -
通过
_comm_manager.setup_metadata设置通信元数据
token_dispatch:
-
输入:
-
hidden_states:[S*B, H] -
probs:[S*B, K] -
输出:
-
hidden_states:[num_dispatched_tokens, H] -
probs:[num_dispatched_tokens, K] -
主要操作:
-
利用 DeepEP 的融合调度内核执行置换和 AlltoAll 通信
-
返回分发后的隐藏状态和概率
dispatch_postprocess:
-
输入:
-
hidden_states:[num_dispatched_tokens, H] -
probs:[num_dispatched_tokens, K] -
输出:
-
hidden_states:[sum(tokens_per_expert), H] -
tokens_per_expert:[num_local_experts] -
permuted_probs:[sum(tokens_per_expert), K] -
主要操作:
-
通过
_comm_manager.get_permuted_hidden_states_by_experts将分发后的令牌转换为每个专家的格式 -
通过
_comm_manager.get_number_of_tokens_per_expert获取每个专家的令牌数量
输出参数
hidden_states:
-
形状:
[sum(tokens_per_expert), H] -
含义: 分发并排列后的隐藏状态,准备输入到专家网络
-
sum(tokens_per_expert): 所有本地专家接收到的令牌总数
tokens_per_expert:
-
形状:
[num_local_experts] -
含义: 每个本地专家接收到的令牌数量
-
用于指导专家网络的计算
permuted_probs:
-
形状:
[sum(tokens_per_expert), K] -
含义: 排列后的路由概率,与
hidden_states对应 -
用于在专家计算后进行加权组合
维度一致性分析
在整个处理流程中,维度的变化遵循以下规律:
-
输入阶段: 张量被展平为二维
[N, H]和[N, K]形式,其中N是令牌总数 -
通信阶段: 利用 DeepEP 的融合内核执行 AlltoAll 通信,改变令牌的分布
-
输出阶段: 张量被重新组织为每个专家的格式,形状变为
[sum(tokens_per_expert), H]
这种设计使得 MoE 层能够高效地处理令牌分发和专家计算,同时保持与底层并行策略的解耦。