[(5 条消息) 图解大模型训练系列之:DeepSpeed-Megatron MoE 并行训练(原理篇) - 知乎](图解大模型训练系列之:DeepSpeed-Megatron MoE并行训练(原理篇))

Pytest 设置

test_token_dispatcher.pyTestFlexDispatcher

1. TestFlexDispatcher 测试类结构

TestFlexDispatcher 是针对 Flex 令牌调度器 的测试类,使用 pytest 框架组织测试用例。核心测试方法包括:

  • test_forward_backward:验证前向/反向传播正确性

  • test_capacity_forward_backward:验证专家容量控制逻辑

  • test_router_padding_for_fp8_forward_backward:验证 FP8 训练时的路由填充逻辑

2. pytest 核心机制应用

(1)测试前置/后置处理

 
class TestFlexDispatcher:
 
def setup_method(self, method):
 
pass # 测试方法执行前的初始化(当前为空)
 
  
 
def teardown_method(self, method):
 
Utils.destroy_model_parallel() # 测试后清理模型并行状态
 
  • setup_method:每个测试方法执行前调用(此处未使用,可用于初始化共享资源)

  • teardown_method:每个测试方法执行后调用,确保模型并行状态正确销毁,避免测试污染

(2)条件跳过测试

 
@pytest.mark.skipif(not is_deep_ep_available(), reason="Deep EP is not available")
 
class TestFlexDispatcher: ...
 
  • 使用 @pytest.mark.skipif 标记类级别跳过条件:当 Deep EP 不可用时,整个 TestFlexDispatcher 类的测试会被跳过

(3)参数化测试(核心)

通过 @pytest.mark.parametrize 实现多场景测试,自动生成测试用例组合:

 
@pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)])
 
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # [False, True](取决于TE版本)
 
@pytest.mark.parametrize("experimental_fusion", [True, False])
 
def test_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion):
 
...
 
  • 参数组合逻辑:3(tp_size/ep_size)× 2(permute_fusion)× 2(experimental_fusion)= 12 个测试用例

  • 动态参数来源permute_fusion_params 根据 Transformer Engine 版本动态生成([False][False, True]

3. 测试流程与输入输出验证

test_forward_backward 为例,完整测试流程如下:

(1)测试环境准备

 
container = MoEModelTestContainer(
 
tp_size=tp_size,
 
ep_size=ep_size,
 
pp_size=1,
 
num_moe_experts=8,
 
moe_token_dispatcher_type="flex", # 指定 Flex 调度器
 
moe_permute_fusion=permute_fusion,
 
experimental_fusion=experimental_fusion, # 控制实验性融合开关
 
...
 
)
 
  • MoEModelTestContainer 是测试容器类,封装了 MoE 层初始化、并行环境配置等逻辑

  • 通过构造参数控制测试场景(如张量并行大小 tp_size、专家并行大小 ep_size 等)

(2)执行测试用例

 
container.dispatcher_dropless_test() # 调用具体测试逻辑
 

dispatcher_dropless_test 方法实现核心验证逻辑:

  • 输入:随机生成的 hidden_states(形状 [bs, seql, hidden_size]

  • 处理流程

  1. 路由计算:probs, indices = moe_layer.router(hidden_states)

  2. 令牌调度:token_permutation(分发→排列→后处理)

  3. 令牌反调度:token_unpermutation(合并→恢复)

  • 输出验证
 
# 验证恢复的 hidden_states 与原始输入一致
 
assert torch.allclose(restored_hidden_states, ans)
 
# 验证反向传播梯度正确性
 
torch.autograd.backward(restored_hidden_states, hidden_states)
 
assert torch.allclose(hidden_states.grad, ans)
 

4. 关键测试场景说明

| 测试方法 | 核心验证目标 | 输入特点 | 输出验证指标 |

| ---------------------------------- | ------------------------------ | ------------------- | -------------------- |

| test_forward_backward | Flex 调度器基本前向/反向传播正确性 | 随机 hidden_states | 恢复的特征、梯度与原始输入一致 |

| test_capacity_forward_backward | 专家容量控制(令牌数量不超过专家容量) | 固定 num_tokens=16 | 令牌数 ≤ 专家容量,输出特征正确性 |

| test_router_padding_for_fp8_test | FP8 训练时路由映射填充逻辑(令牌数对齐到 16 的倍数) | num_tokens=32 | 填充前后输出一致,令牌数为 16 的倍数 |

5. pytest 执行与结果呈现

  • 用例发现:pytest 自动识别 TestFlexDispatcher 类中以 test_ 开头的方法

  • 测试报告:每个参数组合生成独立的测试结果(如 test_forward_backward[8-1-True-True]

  • 失败处理:若 assert 断言失败(如输出特征不一致),pytest 会显示详细的张量差异信息

总结

该文件通过 pytest 的 参数化测试 实现了对 Flex 令牌调度器多场景覆盖,结合 测试容器类 封装环境配置,通过 输入随机化输出断言 验证核心功能正确性。这种设计确保了测试的可扩展性(新增参数组合即可覆盖新场景)和可靠性(自动清理环境 + 独立测试用例)。

使用 Pytest 测试 TestFlexDispatcher

一、单独测试 TestFlexDispatcher 的命令

TestFlexDispatcher 是针对 Flex 令牌调度器 的测试类,使用 pytest 单独运行其测试的命令如下:

 
pytest tests/unit_tests/transformer/moe/test_token_dispatcher.py -k "TestFlexDispatcher" -v
 

参数说明

  • tests/unit_tests/transformer/moe/test_token_dispatcher.py:指定测试文件路径。

  • -k "TestFlexDispatcher":通过关键字匹配,仅运行类名包含 TestFlexDispatcher 的测试(过滤其他测试类,如 TestAllgatherDispatcher)。

  • -v:显示详细测试日志(每个子用例的执行结果)。

二、pytest 核心语法在文件中的应用

1. 测试类定义与生命周期管理

 
@pytest.mark.skipif(not is_deep_ep_available(), reason="Deep EP is not available")
 
class TestFlexDispatcher:
 
def setup_method(self, method):
 
pass # 测试方法执行前的初始化(当前为空,可用于创建共享资源)
 
  
 
def teardown_method(self, method):
 
Utils.destroy_model_parallel() # 测试后清理模型并行状态,避免影响其他测试
 
  • @pytest.mark.skipif(类级别跳过)

条件判断是否跳过整个类。例如,若 Deep EP 库不可用(not is_deep_ep_available()),则 TestFlexDispatcher 所有测试方法均不执行。

  • setup_method/teardown_method:

  • setup_method:每个测试方法(如 test_forward_backward)执行前调用,用于初始化测试环境(如创建临时文件、初始化模型)。

  • teardown_method:每个测试方法执行后调用,用于清理资源(如销毁分布式环境、释放 GPU 内存)。

2. 测试方法与参数化(核心功能)

test_capacity_forward_backward 为例, pytest 通过 参数化 实现多场景自动测试:

 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 
@pytest.mark.internal # 自定义标记(可能用于分类测试,如内部测试/公开测试)
 
@pytest.mark.timeout(120) # 超时控制:测试超过 120 秒则失败
 
@pytest.mark.parametrize("tp_size,ep_size", [(1, 8), (8, 1), (4, 2)]) # 参数组合1
 
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # 参数组合2(动态生成)
 
@pytest.mark.parametrize("experimental_fusion", [True, False]) # 参数组合3
 
def test_capacity_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion):
 
# 测试逻辑:初始化容器 → 执行测试 → 验证结果
 
container = MoEModelTestContainer(tp_size=tp_size, ep_size=ep_size, ...)
 
container.dispatcher_capacity_test() # 核心测试逻辑
 

关键语法解析

  • @pytest.mark.parametrize(参数化测试)

自动生成多组测试用例,每组参数对应一个独立的测试。例如上述代码中:

  • tp_size,ep_size 有 3 种组合:(1,8)(8,1)(4,2)

  • permute_fusionpermute_fusion_params 动态生成(如 [False, True],取决于 Transformer Engine 版本)。

  • experimental_fusion 有 2 种组合:True/False

总测试用例数3(tp/ep) × 2(permute) × 2(experimental) = 12 个,pytest 会自动运行所有组合并报告每个用例的结果。

  • @pytest.mark.timeout(120)

设置单个测试方法的超时时间(120 秒),防止测试因死锁或性能问题无限期阻塞。

  • @pytest.mark.skipif(方法级别跳过)

比类级别更细粒度的跳过条件。例如,若没有 CUDA 环境(not torch.cuda.is_available()),则跳过该测试方法。

3. 测试逻辑与输入输出验证

测试方法的核心流程是:构造输入 → 执行待测试逻辑 → 验证输出。以 dispatcher_capacity_test 为例(被 test_capacity_forward_backward 调用):

输入构造

 
# 生成随机输入张量(模拟模型中间层输出)
 
hidden_states = torch.randn((num_tokens, hidden_size)).cuda()
 
hidden_states.requires_grad = True # 启用梯度计算,用于反向传播测试
 

核心逻辑执行

 
# 路由计算(专家选择概率和索引)
 
probs, indices = moe_layer.router(hidden_states)
 
# 令牌调度(分发到专家)
 
permuted_input, tokens_per_expert, permuted_probs = token_permutation(...)
 
# 令牌合并(从专家收集结果)
 
restored_hidden_states = token_unpermutation(...)
 

输出验证(断言)

 
# 验证专家接收的令牌数不超过容量上限
 
assert torch.all(tokens_per_expert <= capacity), "令牌数超过专家容量"
 
# 验证反向传播梯度正确性(张量接近度比较,允许微小误差)
 
assert torch.allclose(hidden_states.grad, expected_grad), "梯度计算错误"
 
  • torch.allclose:由于浮点数计算误差,不直接用 == 比较张量,而是检查是否在允许的误差范围内(如 rtol=1e-5atol=1e-8)。

三、常用 Pytest 命令扩展

| 命令示例 | 作用 |

|----------|------|

| pytest -k TestFlexDispatcher -v | 仅运行 TestFlexDispatcher 类,显示详细日志 |

| pytest -s | 显示测试中的 print 输出(默认 pytest 会捕获输出) |

| pytest --lf | 仅运行上次失败的测试(快速复现问题) |

| pytest --cov=megatron.core.transformer.moe | 生成测试覆盖率报告(需安装 pytest-cov) |

总结

TestFlexDispatcher 通过 pytest 的 参数化条件跳过生命周期管理断言验证,实现了对 Flex 令牌调度器多场景(不同并行策略、融合开关、容量配置)的自动化测试。核心步骤是:

  1. @pytest.mark.parametrize 定义测试参数组合;

  2. MoEModelTestContainer 构造测试环境;

  3. 调用具体测试方法(如 dispatcher_capacity_test)执行逻辑;

  4. assert + torch.allclose 验证输出和梯度正确性。

TestFlexDispatcher 中三种测试方法的 MoE 场景与参数设计解析

一、test_forward_backward:基础前向/反向传播功能测试

测试目标:验证 Flex 令牌调度器在 无令牌丢弃(dropless)场景 下的核心功能正确性,包括令牌分发、合并及梯度传播。

核心验证逻辑

通过 dispatcher_dropless_test 实现:

  1. 生成随机输入 hidden_states,经 MoE 路由(router)计算专家选择概率(probs)和索引(indices)。

  2. 调用 token_permutation(令牌分发)和 token_unpermutation(令牌合并),验证合并后的 restored_hidden_states 与原始输入一致。

  3. 验证反向传播梯度(hidden_states.grad)与原始输入的一致性,确保梯度计算未被调度逻辑破坏。

参数设计原因

 
@pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)]) # 并行策略组合
 
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # 融合开关
 
@pytest.mark.parametrize("experimental_fusion", [True, False]) # 实验性融合开关
 
def test_forward_backward(...): ...
 
  • tp_size/ep_size: 测试不同并行策略组合(张量并行/专家并行),确保调度器在模型并行拆分时仍能正确路由令牌。例如:

  • (8, 1):8 路张量并行,专家不并行(所有专家在单个 GPU)。

  • (1, 8):专家并行(8 个专家分布在 8 个 GPU),无张量并行。

  • (2, 4):混合并行(2 路张量并行 + 4 路专家并行)。

  • permute_fusion/experimental_fusion: 测试调度器与融合优化(如算子融合、内存优化)的兼容性,确保功能正确性不受性能优化影响。

二、test_capacity_forward_backward:专家容量控制测试

测试目标:验证 Flex 调度器在 专家容量限制场景 下的令牌丢弃/截断逻辑,确保专家接收的令牌数不超过其容量上限。

核心验证逻辑

通过 dispatcher_capacity_test 实现:

  1. 计算专家容量(capacity):基于 moe_expert_capacity_factor(容量系数)、令牌总数和专家数量。

  2. 验证 tokens_per_expert(每个专家实际接收的令牌数)是否小于等于 capacity * ep_size * tp_size(考虑并行拆分后的实际容量)。

  3. 验证令牌合并后的输出和梯度与预期结果一致,确保容量控制未破坏计算正确性。

参数设计原因

 
container = MoEModelTestContainer(
 
moe_expert_capacity_factor=0.5, # 容量系数(控制专家最大令牌数)
 
moe_token_drop_policy="probs", # 令牌丢弃策略(按概率丢弃超额令牌)
 
hidden_size=4, # 小隐藏维度加速测试
 
...
 
)
 
  • moe_expert_capacity_factor=0.5: 将专家容量设为默认值的 50%,刻意降低容量以触发令牌丢弃逻辑,验证调度器是否能正确截断超额令牌。

  • moe_token_drop_policy="probs": 测试按概率丢弃令牌的策略(而非随机丢弃),确保高概率令牌优先保留。

  • hidden_size=4: 使用小维度输入加速测试,减少计算资源消耗。

三、test_router_padding_for_fp8_test:FP8 训练场景下的路由填充测试

测试目标:验证 Flex 调度器在 FP8 混合精度训练 时的令牌填充逻辑,确保专家输入长度对齐到 16 的倍数(FP8 kernels 要求),且填充不影响计算结果。

核心验证逻辑

通过 dispatcher_router_padding_for_fp8_test 实现:

  1. 对比两种模式:moe_router_padding_for_fp8=False(无填充)和 True(填充到 16 的倍数)。

  2. 验证填充模式下 tokens_per_expert 为 16 的倍数(tokens_per_expert_2 % 16 == 0)。

  3. 验证两种模式的输出 restored_hidden_states 和梯度完全一致,确保填充未引入计算误差。

参数设计原因

 
container = MoEModelTestContainer(
 
num_moe_experts=32, # 更多专家,增加令牌分布随机性
 
moe_router_topk=4, # 更高 top-k,增加每个令牌选择的专家数
 
fp8="hybrid", # 启用 FP8 混合精度
 
...
 
)
 
  • num_moe_experts=32/moe_router_topk=4: 增加专家数量和每个令牌选择的专家数,使令牌分布更分散,更容易出现非 16 倍数的令牌数,从而测试填充逻辑的鲁棒性。

  • fp8="hybrid": 显式启用 FP8 模式,确保测试环境与实际 FP8 训练一致。

  • @pytest.mark.skipif(not is_te_min_version("1.7.0")): FP8 路由填充依赖 Transformer Engine 1.7.0+ 的特性,低版本 TE 不支持,故跳过。

总结:三种测试的互补性

| 测试方法 | 核心场景 | 关键参数设计逻辑 |

|-----------------------------------|---------------------------|--------------------------------------------------------------------------------|

| test_forward_backward | 基础功能验证(无丢弃) | 多并行策略组合,验证调度器在不同模型拆分下的正确性。 |

| test_capacity_forward_backward | 专家容量控制(有丢弃) | 低容量系数触发丢弃逻辑,验证令牌截断和专家负载均衡。 |

| test_router_padding_for_fp8_test| FP8 训练兼容性(需填充) | 更多专家/更高 top-k 增加填充必要性,验证 FP8 kernels 所需的输入对齐要求。 |

三者共同覆盖了 MoE 调度器的 基础功能、容量控制、精度兼容性 三大核心场景,参数设计均针对场景痛点(如并行拆分、容量超限、FP8 对齐),确保测试的全面性和有效性。

基础测试流程

TestFlexDispatcher.test_forward_backward 调用流程与模型配置解析

1. 测试方法定义与参数化(pytest 驱动)

 
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # 条件跳过:无CUDA时不执行
 
@pytest.mark.internal # 内部测试标记(用于分类测试)
 
@pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4)]) # 参数组合1:张量并行/专家并行策略
 
@pytest.mark.parametrize("permute_fusion", permute_fusion_params) # 参数组合2:算子融合开关(True/False)
 
@pytest.mark.parametrize("experimental_fusion", [True, False]) # 参数组合3:实验性融合开关
 
def test_forward_backward(self, tp_size, ep_size, permute_fusion, experimental_fusion): # 参数化输入
 

作用

  • @pytest.mark.parametrize 生成多组测试用例(如 tp_size=8,ep_size=1tp_size=2,ep_size=4),覆盖不同并行策略和融合配置。

  • tp_size/ep_size:控制模型并行拆分方式(张量并行/专家并行),验证 Flex 调度器在不同并行场景下的兼容性。

2. 实验性融合开关配置

 
if experimental_fusion:
 
config.ENABLE_EXPERIMENTAL = True # 启用实验性融合(如内存优化或算子合并)
 

作用

  • 动态控制全局配置 config.ENABLE_EXPERIMENTAL,测试调度器与实验性优化的兼容性。

3. 创建 MoE 测试容器(核心配置入口)

 
container = MoEModelTestContainer(
 
tp_size=tp_size, # 张量并行大小(来自参数化输入)
 
ep_size=ep_size, # 专家并行大小(来自参数化输入)
 
pp_size=1, # 流水线并行大小(固定为1,无需测试)
 
num_moe_experts=8, # 总专家数:8个专家
 
moe_router_topk=2, # 路由选择Top-2专家(每个令牌分配给2个专家)
 
moe_router_load_balancing_type="aux_loss", # 负载均衡策略:辅助损失函数
 
moe_token_dispatcher_type="flex", # 令牌调度器类型:Flex(测试目标)
 
moe_permute_fusion=permute_fusion, # 排列融合开关(来自参数化输入)
 
hidden_size=4, # 隐藏层维度(小维度加速测试)
 
moe_enable_deepep=True, # 启用Deep EP优化(Flex调度器依赖)
 
)
 

关键配置解析

  • 并行策略tp_size(张量并行)和 ep_size(专家并行)决定模型拆分方式。例如,tp_size=8,ep_size=1 表示 8 路张量并行、专家不并行;ep_size=8 表示 8 个专家分布在 8 个 GPU。

  • MoE 核心参数

  • num_moe_experts=8:总专家数为 8,与 ep_size 共同决定每个 GPU 的本地专家数(num_local_experts = 8 // ep_size)。

  • moe_router_topk=2:每个令牌路由到 2 个专家,验证多专家分配场景。

  • moe_token_dispatcher_type="flex":显式指定使用 Flex 调度器(测试核心目标)。

4. MoEModelTestContainer 初始化流程(模型环境搭建)

MoEModelTestContainer__init__ 方法完成以下关键步骤:

(1)模型并行环境初始化

 
Utils.initialize_model_parallel(
 
tensor_model_parallel_size=tp_size,
 
expert_model_parallel_size=ep_size,
 
... # 其他并行参数
 
)
 

作用:通过 Utils 工具类设置分布式环境,包括张量并行组、专家并行组的初始化,确保各 GPU 间通信正常。

(2)生成 TransformerConfig(模型配置核心)

 
self.config = TransformerConfig(
 
tensor_model_parallel_size=tp_size,
 
expert_model_parallel_size=ep_size,
 
moe_token_dispatcher_type="flex", # 绑定Flex调度器
 
hidden_size=4, # 小隐藏维度(减少计算量)
 
... # 其他MoE配置(如容量系数、融合开关)
 
)
 

作用:创建 TransformerConfig 对象,集中管理模型架构、并行策略、MoE 特性等配置,后续传递给 MoELayer

(3)初始化 MoE 层

 
self.moe_layer = self.new_moe_layer() # 调用new_moe_layer创建MoE层
 

new_moe_layer 方法细节:

 
def new_moe_layer(self, **kargs):
 
transformer_layer_spec = get_gpt_layer_local_spec(...) # 获取GPT层规格(含MLP子模块)
 
new_config = dataclasses.replace(self.config, **kargs) # 合并配置(支持动态修改)
 
moe_layer = MoELayer(new_config, transformer_layer_spec.submodules.mlp.submodules).cuda() # 创建MoE层并移至GPU
 
return moe_layer
 

作用

  • MoELayer 是核心 MoE 组件,包含路由模块(router)和 Flex 令牌调度器(token_dispatcher)。

  • transformer_layer_spec 指定 MLP 子模块结构,确保专家网络符合 GPT 架构。

5. 执行测试逻辑:dispatcher_dropless_test

 
container.dispatcher_dropless_test() # 调用无令牌丢弃场景的测试
 

该方法是实际测试逻辑的入口,内部流程如下:

(1)生成测试输入

 
bs = 32 # batch size
 
seql = 8 # 序列长度
 
hidden_states = torch.randn((bs, seql, moe_layer.config.hidden_size)).cuda() # 随机输入张量 [32,8,4]
 
hidden_states.requires_grad = True # 启用梯度计算(用于反向传播测试)
 

(2)路由计算(专家选择)

 
probs, indices = moe_layer.router(hidden_states) # 路由输出:专家选择概率(probs)和索引(indices)
 
probs = torch.ones_like(probs) / moe_layer.router.topk # 归一化概率(等概率选择Top-k专家)
 

(3)Flex 调度器令牌分发与合并

 
# 令牌分发:将hidden_states按indices分配给对应专家
 
permuted_local_hidden_states, tokens_per_expert, permuted_probs = token_permutation(
 
moe_layer.token_dispatcher, hidden_states, probs, indices
 
)
 
  
 
# 令牌合并:从专家收集结果并恢复原始顺序
 
restored_hidden_states, restored_bias = token_unpermutation(
 
moe_layer.token_dispatcher, permuted_local_hidden_states
 
)
 

(4)结果验证(断言)

 
# 验证前向传播:恢复的hidden_states与原始输入一致
 
assert torch.allclose(restored_hidden_states, ans), "前向传播结果不匹配"
 
  
 
# 验证反向传播:梯度与原始输入一致
 
torch.autograd.backward(restored_hidden_states, hidden_states) # 反向传播
 
assert torch.allclose(hidden_states.grad, ans), "反向传播梯度不匹配"
 

6. 清理实验性配置

 
config.ENABLE_EXPERIMENTAL = False # 重置实验性开关(避免影响其他测试)
 

总结:调用流程与配置关系

  1. pytest 参数驱动:通过 @pytest.mark.parametrize 生成多组 tp_size/ep_size/fusion 组合,覆盖不同并行和优化场景。

  2. 容器初始化MoEModelTestContainer 封装模型并行环境、TransformerConfigMoELayer,统一管理测试依赖。

  3. Flex 调度器测试:核心通过 dispatcher_dropless_test 验证令牌分发/合并的前向/反向正确性,确保 Flex 调度器在各种配置下功能正常。

整个流程围绕 Flex 令牌调度器 的核心功能展开,通过参数化和断言验证,确保其在不同并行策略、融合优化下的鲁棒性。

test_moe_deepep.pytoken_permutation 函数详解

函数定义

 
def token_permutation(token_dispatcher, hidden_states, probs, indices):
 
hidden_states, probs = token_dispatcher.dispatch_preprocess(hidden_states, indices, probs)
 
hidden_states, probs = token_dispatcher.token_dispatch(hidden_states, probs)
 
hidden_states, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess(
 
hidden_states, probs
 
)
 
return hidden_states, tokens_per_expert, permuted_probs
 

输入参数

  1. token_dispatcher:
  • 类型: MoEFlexTokenDispatcher 实例

  • 作用: MoE 令牌分发器,负责管理令牌到专家的路由和通信

  1. hidden_states:
  • 形状: [S, B, H][S*B, H]

  • 含义:

  • S: 序列长度

  • B: 批次大小

  • H: 隐藏层维度

  • dispatch_preprocess 中会被展平为 [-1, H],即 [S*B, H]

  1. probs:
  • 形状: [S, B, E][S*B, E]

  • 含义:

  • E: 专家数量

  • 表示每个令牌路由到各个专家的概率

  • dispatch_preprocess 中会被展平为 [-1, E],即 [S*B, E]

  1. indices:
  • 形状: [S, B, K][S*B, K]

  • 含义:

  • K: 每个令牌选择的专家数量(top-K)

  • 表示每个令牌应该路由到的专家索引

  • dispatch_preprocess 中会被展平为 [-1, K],即 [S*B, K]

处理步骤

  1. dispatch_preprocess:
  • 输入:

  • hidden_states: [S*B, H]

  • indices: [S*B, K]

  • probs: [S*B, K] (注意:这里 probs 通常是从 indices 和原始 probs 计算得出的)

  • 输出:

  • hidden_states: [S*B, H]

  • probs: [S*B, K]

  • 主要操作:

  • 将输入张量展平为二维

  • 调用 _initialize_metadata 将路由图和概率扩展到 TPxEP 组,形状变为 [S*B, world_size, num_local_experts]

  • 通过 _comm_manager.setup_metadata 设置通信元数据

  1. token_dispatch:
  • 输入:

  • hidden_states: [S*B, H]

  • probs: [S*B, K]

  • 输出:

  • hidden_states: [num_dispatched_tokens, H]

  • probs: [num_dispatched_tokens, K]

  • 主要操作:

  • 利用 DeepEP 的融合调度内核执行置换和 AlltoAll 通信

  • 返回分发后的隐藏状态和概率

  1. dispatch_postprocess:
  • 输入:

  • hidden_states: [num_dispatched_tokens, H]

  • probs: [num_dispatched_tokens, K]

  • 输出:

  • hidden_states: [sum(tokens_per_expert), H]

  • tokens_per_expert: [num_local_experts]

  • permuted_probs: [sum(tokens_per_expert), K]

  • 主要操作:

  • 通过 _comm_manager.get_permuted_hidden_states_by_experts 将分发后的令牌转换为每个专家的格式

  • 通过 _comm_manager.get_number_of_tokens_per_expert 获取每个专家的令牌数量

输出参数

  1. hidden_states:
  • 形状: [sum(tokens_per_expert), H]

  • 含义: 分发并排列后的隐藏状态,准备输入到专家网络

  • sum(tokens_per_expert): 所有本地专家接收到的令牌总数

  1. tokens_per_expert:
  • 形状: [num_local_experts]

  • 含义: 每个本地专家接收到的令牌数量

  • 用于指导专家网络的计算

  1. permuted_probs:
  • 形状: [sum(tokens_per_expert), K]

  • 含义: 排列后的路由概率,与 hidden_states 对应

  • 用于在专家计算后进行加权组合

维度一致性分析

在整个处理流程中,维度的变化遵循以下规律:

  1. 输入阶段: 张量被展平为二维 [N, H][N, K] 形式,其中 N 是令牌总数

  2. 通信阶段: 利用 DeepEP 的融合内核执行 AlltoAll 通信,改变令牌的分布

  3. 输出阶段: 张量被重新组织为每个专家的格式,形状变为 [sum(tokens_per_expert), H]

这种设计使得 MoE 层能够高效地处理令牌分发和专家计算,同时保持与底层并行策略的解耦。