Megatron-LM processing flow with DeepEP
The main idea of the DeepEP dispatcher:
(1) Token permutation stage: call the fused_dispatch interface to run all-to-all and all-gather on tokens within the tp_ep group, which yields the global hidden_states belonging to this rank's local experts; permuting them by local expert indices then gives the input to the experts' MLP;
(2) Token unpermutation stage: unpermute the experts' MLP output according to indices and probs, then call the fused_combine interface to run reduce-scatter and all-to-all on tokens within the tp_ep group, producing the final output.
Question: DeepEP does not permute the local tokens first; it performs the all-to-all communication first. Does this save communication volume?
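As a toy sketch of the two stages (pure Python with illustrative names only; the real path is the fused_dispatch / fused_combine kernels over the tp_ep group):

```python
# Toy sketch of dispatch -> expert MLP -> combine. All helper names here are
# hypothetical; real DeepEP does the grouping and reduction in fused CUDA kernels.

def toy_dispatch(tokens, topk_indices, num_experts):
    """Group token copies by expert id (stands in for a2a + permute)."""
    per_expert = [[] for _ in range(num_experts)]
    src = [[] for _ in range(num_experts)]  # remember origins for combine
    for t, (tok, idxs) in enumerate(zip(tokens, topk_indices)):
        for e in idxs:
            per_expert[e].append(tok)
            src[e].append(t)
    return per_expert, src

def toy_combine(expert_out, src, probs, num_tokens):
    """Weighted sum of each token's topk expert outputs (stands in for combine)."""
    out = [0.0] * num_tokens
    for e, (outs, origins) in enumerate(zip(expert_out, src)):
        for y, t in zip(outs, origins):
            out[t] += probs[t][e] * y
    return out

tokens = [1.0, 2.0, 3.0]
topk = [[0, 1], [1, 2], [0, 2]]  # top-2 routing over 3 experts
probs = [{0: 0.6, 1: 0.4}, {1: 0.5, 2: 0.5}, {0: 0.3, 2: 0.7}]
per_expert, src = toy_dispatch(tokens, topk, num_experts=3)
expert_out = [[2 * x for x in bucket] for bucket in per_expert]  # stand-in MLP: scale by 2
combined = toy_combine(expert_out, src, probs, num_tokens=3)
# each token's probs sum to 1, so combined[t] recovers 2 * tokens[t]
```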
token permutation
This corresponds to:

```python
# tests\unit_tests\test_moe_deepep.py  @@ 19
# the dispatcher call in the test
def token_permutation(token_dispatcher, hidden_states, probs, indices):

# megatron\core\transformer\moe\token_dispatcher.py  @@ 982
def dispatch(

# megatron\core\transformer\moe\fused_a2a.py  @@ 72
# calls into DeepEP
def forward(
```
(1) Data preprocessing: reshape hidden_states to (s*b, h) and process probs, producing the token_indices and token_probs tensors that serve as input to the fused_dispatch fused kernel;
(2) fused dispatch: use a fused CUDA kernel to run all-to-all and all-gather on tokens within the tp_ep group, producing:
- global_hidden_states: hidden_states of the global tokens dispatched to this rank's local_experts, shape (global_num, h)
- dispatched_indices: expert indices of the global tokens dispatched to this rank's local_experts, shape (global_num, topk)
- dispatched_probs: expert probs of the global tokens dispatched to this rank's local_experts, shape (global_num, topk)
- num_tokens_per_local_expert: number of global tokens dispatched to this rank's local_experts
- NOTE: dispatched_probs and dispatched_indices contain **-1** entries, marking experts that are not local to this rank; non-negative entries refer to this rank's local experts
- Question: how are the all-to-all input splits defined? input_splits=(ep_rank)
(3) Convert the expert indices and expert probs to multi-hot format:
- convert dispatched_indices to multi-hot format: (global_num, topk) → (global_num, num_experts)
- convert dispatched_probs to multi-hot format: (global_num, topk) → (global_num, num_experts)
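A minimal sketch of this conversion (pure Python, illustrative; the real code operates on torch tensors, with -1 entries masked out):

```python
def to_multihot(dispatched_indices, dispatched_probs, num_experts):
    """(global_num, topk) -> (global_num, num_experts); -1 entries are skipped."""
    n = len(dispatched_indices)
    multihot_idx = [[0] * num_experts for _ in range(n)]
    multihot_probs = [[0.0] * num_experts for _ in range(n)]
    for t in range(n):
        for k, e in enumerate(dispatched_indices[t]):
            if e >= 0:  # -1 marks an expert that is not local to this rank
                multihot_idx[t][e] = 1
                multihot_probs[t][e] = dispatched_probs[t][k]
    return multihot_idx, multihot_probs

idx, probs = to_multihot([[1, -1], [-1, 0]], [[0.7, 0.3], [0.4, 0.6]], num_experts=2)
# idx   == [[0, 1], [1, 0]]
# probs == [[0.0, 0.7], [0.6, 0.0]]
```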
(4) Permute the global hidden_states:
- permute the global hidden_states according to the multi-hot dispatched_indices to obtain the experts' input; the result is already sorted by expert index, so no extra sort_chunk is needed
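The permute step can be sketched as follows (pure Python, illustrative; `row_map` is a hypothetical bookkeeping output used later for the unpermute):

```python
def permute_by_multihot(hidden_states, multihot_idx, num_experts):
    """Gather each token's row once per selected expert, grouped by expert id,
    so the result is already sorted by expert (no extra sort_chunk needed)."""
    permuted, row_map = [], []  # row_map records the source row of each output row
    for e in range(num_experts):
        for t, row in enumerate(hidden_states):
            if multihot_idx[t][e]:
                permuted.append(row)
                row_map.append(t)
    return permuted, row_map

h = [[1.0], [2.0], [3.0]]
mh = [[1, 0], [0, 1], [1, 1]]  # token0 -> e0, token1 -> e1, token2 -> e0 and e1
permuted, row_map = permute_by_multihot(h, mh, num_experts=2)
# permuted == [[1.0], [3.0], [2.0], [3.0]]  (expert 0's rows first, then expert 1's)
# row_map  == [0, 2, 1, 2]
```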
token unpermutation
(1) Unpermute the experts' output:
- unpermute the experts' output according to the multi-hot dispatched_indices and dispatched_probs
(2) fused combine: use a fused CUDA kernel to run reduce-scatter and all-to-all on tokens within the tp_ep group, then reshape the resulting hidden_states to (s, b, h)
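The local unpermute before fused_combine can be sketched as a prob-weighted scatter-add (pure Python, illustrative; `row_map` / `row_probs` are hypothetical names for the permute-stage bookkeeping):

```python
def unpermute(expert_out, row_map, row_probs, num_tokens):
    """Scatter-add each expert-output row back to its source token,
    scaled by that (token, expert) routing prob."""
    out = [0.0] * num_tokens
    for y, t, p in zip(expert_out, row_map, row_probs):
        out[t] += p * y
    return out

expert_out = [2.0, 6.0, 4.0, 6.0]  # hypothetical MLP outputs, one per permuted row
row_map    = [0, 2, 1, 2]          # source token of each permuted row
row_probs  = [1.0, 0.5, 1.0, 0.5]  # routing prob of each (token, expert) pair
out = unpermute(expert_out, row_map, row_probs, num_tokens=3)
# out == [2.0, 4.0, 6.0]
```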
```python
if config.moe_token_dispatcher_type == "flex":
    self.token_dispatcher = MoEFlexTokenDispatcher(
        self.num_local_experts, self.local_expert_indices, config=self.config
    )
```
Megatron-LM processing flow without DeepEP
token permutation
(1) Data preprocessing:
- reshape hidden_states to (s*b, h)
- preprocess routing_map to produce input_splits / output_splits / output_splits_tp etc., which feed the subsequent all-to-all and all-gather;
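A sketch of how the splits can be derived from a routing map (illustrative pure Python; the real preprocessing works on torch tensors and also produces output_splits_tp):

```python
def compute_input_splits(routing_map, num_ranks, experts_per_rank):
    """input_splits[r] = number of (token, expert) copies this rank sends to ep rank r.
    routing_map[t] lists the expert ids chosen for local token t."""
    splits = [0] * num_ranks
    for experts in routing_map:
        for e in experts:
            splits[e // experts_per_rank] += 1
    return splits

# 2 ranks, 2 experts per rank (experts 0-1 on rank0, experts 2-3 on rank1)
routing_map = [[0, 3], [1, 2], [2, 3]]
print(compute_input_splits(routing_map, num_ranks=2, experts_per_rank=2))
# [2, 4]: two copies go to rank0 (experts 0, 1), four to rank1 (experts 2, 3)
```

output_splits can then be obtained by exchanging the input_splits sizes across the ep group.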
(2) hidden_states permute:
- local_permuted_hidden_states: permute hidden_states according to routing_map; these are the hidden_states of this rank's local tokens destined for experts across the cluster, shape (local_num, h)
(3) first all-to-all:
- global_hidden_states: within ep_group, run all-to-all on local_permuted_hidden_states to collect the hidden_states of the global tokens dispatched to this rank's local_experts, shape (global_num, h)
(4) tp allgather:
- global_hidden_states: when TP>1, all-gather the all-to-all result along dim0 within tp_group; dim0 grows to (global_num2, h)
(5) sort by expert index:
- sort the global tokens' hidden states by local expert id to obtain the experts' input
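The sort step can be sketched as a stable argsort over expert ids (illustrative pure Python):

```python
def sort_by_expert(global_hidden_states, expert_ids):
    """Stable-sort rows by local expert id so each expert sees a contiguous chunk;
    `order` is what the unpermutation step later inverts."""
    order = sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])
    return [global_hidden_states[i] for i in order], order

rows = [[10.0], [20.0], [30.0], [40.0]]
expert_ids = [1, 0, 1, 0]  # local expert assigned to each received row
sorted_rows, order = sort_by_expert(rows, expert_ids)
# sorted_rows == [[20.0], [40.0], [10.0], [30.0]]; order == [1, 3, 0, 2]
```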
token unpermutation
(1) un-sort by expert index:
- restore the original (pre-sort) ordering of the outputs computed by the experts
(2) tp reduce-scatter:
- if TP>1, reduce-scatter the outputs
(3) second all-to-all:
- local_permuted_hidden_states: run all-to-all on the global hidden_states to recover the hidden_states of the local tokens dispatched to experts across the cluster, shape (local_num, h)
(4) hidden_states unpermute:
- hidden_states: unpermute local_permuted_hidden_states according to routing_map and probs, then reshape to (s, b, h)
DeepEP source code walkthrough
Installing the deep_ep_cpp library
```shell
# project code: https://github.com/deepseek-ai/DeepEP

# install NVSHMEM
# host: /dev/gdrdrv
wget https://github.com/NVIDIA/gdrcopy/archive/refs/tags/v2.4.4.tar.gz
cd gdrcopy-2.4.4/
make -j$(nproc)
sudo make prefix=/opt/gdrcopy install
apt-get install devscripts
cd packages
CUDA=/usr/local/cuda ./build-deb-packages.sh
sudo dpkg -i gdrdrv-dkms_2.4.4_amd64.Ubuntu22_04.deb libgdrapi_2.4.4_amd64.Ubuntu22_04.deb gdrcopy-tests_2.4.4_amd64.Ubuntu22_04+cuda12.4.deb gdrcopy_2.4.4_amd64.Ubuntu22_04.deb
cd ..
sudo ./insmod.sh

# container
apt-get update
apt-get install dkms
dpkg -i libgdrapi_2.4.4_amd64.Ubuntu22_04.deb gdrcopy-tests_2.4.4_amd64.Ubuntu22_04+cuda12.4.deb gdrcopy_2.4.4_amd64.Ubuntu22_04.deb

# host
sudo apt-get install libopenmpi-dev

# Build and make symbolic links for SO files
NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py build
# You may modify the specific SO names according to your own platform
ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.so

# usage
from deep_ep import Buffer, EventOverlap
```
dependency: nvshmem — install the nvshmem dependency library and export its install path in the NVSHMEM_DIR environment variable.
https://github.com/deepseek-ai/DeepEP/blob/main/third-party/README.md
|file|description|
|---|---|
|`nvshmem_dir = os.getenv('NVSHMEM_DIR', None)`||
|library: nvshmem_dir/lib||
|source file: csrc/deep_ep.cpp||
|source file: csrc/kernels/runtime.cu|helper functions (type conversion, memory operations, etc.) used by the other kernels|
|source file: csrc/kernels/intranode.cu|expert-parallel communication within a single node, using NVLink for high-speed transfers|
|source file: csrc/kernels/internode.cu|expert-parallel communication across nodes, using RDMA|
|source file: csrc/kernels/internode_ll.cu|low-latency cross-node communication kernels, also RDMA-based|
compile args:

```python
nvcc_dlink = ['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']
extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
extra_compile_args = {
    'cxx': cxx_flags,
    'nvcc': nvcc_flags,
    'nvcc_dlink': nvcc_dlink
}
```
link args:

```python
extra_link_args = ['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']
```
DeepEP TestCases
```shell
# training
python tests/test_intranode.py
python tests/test_internode.py
# inference
python tests/test_low_latency.py
```
Runtime Buffer接口
features: the core expert-parallel communication buffer, which supports:
- high-throughput intranode all-to-all (dispatch and combine, using NVLink)
- high-throughput internode all-to-all (dispatch and combine, using RDMA without AR)
- low-latency all-to-all (dispatch and combine, using RDMA, AR supported)
Python interface: deep_ep/buffer.py
C++ header: csrc/deep_ep.hpp
C++ source: csrc/deep_ep.cpp
Buffer attributes:

|attribute|description|
|---|---|
|group|the communication group (the EP group?)|
|group_size|number of ranks in the group (= num_ranks)|
|rank|this rank's id within the group|
|num_nvl_bytes|buffer size for intra-node NVLink communication|
|num_rdma_bytes|buffer size for inter-node RDMA communication|
|low_latency_mode|whether to use low-latency mode, default False; the low-latency kernels are typically used in the inference decoding phase|
|num_qps_per_rank|number of RDMA QPs; in low-latency mode it must equal the number of local experts|
|rdma_rank|rank / 8|
|num_rdma_ranks|max(1, group_size / 8), equal to the number of nodes?|
|nvl_rank|rank % 8|
|num_nvl_ranks|min(group_size, 8)|
|ipc_handles[8]|of type cudaIpcMemHandle_t; each device rank holds 8 ipc_handles?|
|task_fifo_ptrs[8]||
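Plugging the formulas from the table into a concrete 16-rank example (two 8-GPU nodes):

```python
# A global rank decomposes into (rdma_rank, nvl_rank) = (node index, GPU within node).
group_size = 16  # two 8-GPU nodes
for rank in (0, 7, 8, 15):
    rdma_rank, nvl_rank = rank // 8, rank % 8
    num_rdma_ranks = max(1, group_size // 8)  # == number of nodes
    num_nvl_ranks = min(group_size, 8)        # == GPUs per node
    print(rank, rdma_rank, nvl_rank, num_rdma_ranks, num_nvl_ranks)
# rank 8 -> rdma_rank 1, nvl_rank 0; rank 15 -> rdma_rank 1, nvl_rank 7
```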
Buffer Init
(1) Initialize the runtime buffer:

```python
self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode)
```
(2) Gather device_id, ipc_handle, and nvshmem unique ids over ep_group:
(a) gather every rank's device_id over ep_group:

```python
device_ids = [None, ] * self.group_size
local_device_id = self.runtime.get_local_device_id()
dist.all_gather_object(device_ids, local_device_id, group)
```

(b) gather every rank's inter-process memory ipc_handles over ep_group:

```python
ipc_handles = [None, ] * self.group_size
local_ipc_handle = self.runtime.get_local_ipc_handle()
dist.all_gather_object(ipc_handles, local_ipc_handle, group)
```

(c) gather every rank's NVSHMEM unique IDs over ep_group:
only rdma_rank 0 obtains an nvshmem unique id

```python
nvshmem_unique_ids = [None, ] * self.group_size
if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0):
    root_unique_id = self.runtime.get_local_nvshmem_unique_id()
dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group)
root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)]
```
(3) Sync device_id, ipc_handle, and nvshmem unique id into the runtime:
(a) initialize buffer_ptrs and task_fifo_ptrs:
- copy ipc_handles[i] to the CPU-side array; buffer_ptrs[i] holds the address obtained through the handle;
- derive task_fifo_ptrs[i] from buffer_ptrs[i] (offset by num_nvl_bytes);
- copy buffer_ptrs and task_fifo_ptrs to the GPU

```cpp
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) {
    auto handle_str = std::string(all_gathered_handles[offset + i].value());
    if (offset + i != rank) {
        std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE);
        task_fifo_ptrs[i] = reinterpret_cast<int*>(reinterpret_cast<uint8_t*>(buffer_ptrs[i]) + num_nvl_bytes);
    } else {
        EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0);
    }
}
// Copy all buffer and task pointers to GPU
CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice));
```

(b) initialize the NVSHMEM rdma_buffer_ptr:

```cpp
std::vector<uint8_t> root_unique_id(root_unique_id_opt->size());
auto root_unique_id_str = root_unique_id_opt->cast<std::string>();
std::memcpy(root_unique_id.data(), root_unique_id_str.c_str(), root_unique_id_opt->size());
auto nvshmem_rank = low_latency_mode ? rank : rdma_rank;
auto num_nvshmem_ranks = low_latency_mode ? num_ranks : num_rdma_ranks;
internode::barrier();
// Allocate
rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES);
// Clean buffer (mainly for low-latency mode)
CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes));
// Barrier
internode::barrier();
```
fused dispatch
Set the Buffer's SM count:

```python
Buffer.set_num_sms(24)
# num_sms must be even: each channel uses 2 blocks,
# even-numbered blocks are used for send and odd-numbered blocks for recv
# (on the C++ side: num_channels = config.num_sms / 2)
```
Initialize the runtime Buffer: the NVLink and RDMA communication buffer sizes are derived from the SM count and group_size

```python
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    global _buffer
    # NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
        num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
        num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
    # Allocate a buffer if not existed or not enough buffer size
    # NOTES: the adaptive routing configuration of the network **must be off**
    if _buffer is None or _buffer.group != group or _buffer.num_nvl_bytes < num_nvl_bytes or _buffer.num_rdma_bytes < num_rdma_bytes:
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer
```
DeepEP provides configurations for different rank counts, and the SM count can be configured freely (default 20).
```python
# Intranode
if num_ranks <= 8:
    return Config(Buffer.num_sms, 6, 256, 6, 128)

# Internode
config_map = {
    16: Config(Buffer.num_sms, 16, 288, 20, 128),
    24: Config(Buffer.num_sms, 8, 288, 32, 128),
    32: Config(Buffer.num_sms, 8, 288, 32, 128),
    64: Config(Buffer.num_sms, 20, 288, 28, 128),
    128: Config(Buffer.num_sms, 20, 560, 32, 128),
    144: Config(Buffer.num_sms, 32, 720, 12, 128),
    160: Config(Buffer.num_sms, 28, 720, 12, 128),
}
```

```cpp
struct Config {
    int num_sms;
    int num_max_nvl_chunked_send_tokens;
    int num_max_nvl_chunked_recv_tokens;
    int num_max_rdma_chunked_send_tokens;
    int num_max_rdma_chunked_recv_tokens;
};
```
Compute the layout before dispatch: get_dispatch_layout

|tensor|shape, dtype|description|
|---|---|---|
|num_tokens_per_rank|(num_ranks), int|number of tokens sent to each ep rank|
|num_tokens_per_rdma_rank|(num_rdma_ranks), int|number of tokens sent to each rdma rank; set to None for intra-node-only communication|
|is_token_in_rank|(num_tokens, num_ranks), bool|whether each token is sent to a given ep rank|
|num_tokens_per_expert|(num_experts), int|number of tokens sent to each expert|
|topk_idx|(num_tokens, topk), int64|each token's topk experts|
Python call interface: get_dispatch_layout

```python
num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, previous_event = \
    _buffer.get_dispatch_layout(topk_idx, num_experts,
                                previous_event=previous_event, async_finish=True,
                                allocate_on_comm_stream=previous_event is not None)
```
C++ call interface: get_dispatch_layout

```cpp
void get_dispatch_layout(const int64_t* topk_idx,
                         int* num_tokens_per_rank, int* num_tokens_per_rdma_rank,
                         int* num_tokens_per_expert, bool* is_token_in_rank,
                         int num_tokens, int num_topk, int num_ranks, int num_experts,
                         cudaStream_t stream) {
    // each SM covers 32 experts and 8 local ranks
    constexpr int kNumThreads = 256, kNumExpertsPerSM = 32, kNumRanksPerSM = 8;
    // number of SMs required
    int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM;
    // launch with 256 threads on num_sms SMs
    SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream);
    LAUNCH_KERNEL(&cfg, (get_dispatch_layout<kNumThreads, kNumExpertsPerSM, kNumRanksPerSM>),
                  topk_idx, num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank,
                  num_tokens, num_topk, num_ranks, num_experts);
}
```
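The semantics of these outputs can be checked against a small host-side reference (illustrative pure Python; the real computation happens in the CUDA kernel above):

```python
def dispatch_layout_reference(topk_idx, num_ranks, num_experts):
    """Host-side reference for get_dispatch_layout's outputs.
    topk_idx[t] lists the expert ids for token t (-1 = not selected)."""
    experts_per_rank = num_experts // num_ranks
    num_tokens_per_rank = [0] * num_ranks
    num_tokens_per_expert = [0] * num_experts
    is_token_in_rank = [[False] * num_ranks for _ in topk_idx]
    for t, experts in enumerate(topk_idx):
        for e in experts:
            if e >= 0:
                num_tokens_per_expert[e] += 1
                is_token_in_rank[t][e // experts_per_rank] = True
    # a token is counted once per rank, even if several of its topk experts live there
    for t in range(len(topk_idx)):
        for r in range(num_ranks):
            num_tokens_per_rank[r] += is_token_in_rank[t][r]
    return num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank

per_rank, per_expert, in_rank = dispatch_layout_reference(
    [[0, 1], [2, 3], [1, -1]], num_ranks=2, num_experts=4)
# per_rank == [2, 1]: token0 and token2 hit rank0's experts; token1 hits rank1's
```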
buffer.dispatch:
Distributes tokens to the different ranks. The intra-node kernels require all GPUs to be interconnected via NVLink; the inter-node kernels additionally require same-index GPUs across nodes to be interconnected via RDMA, and adaptive routing must not be enabled.

|argument|shape, dtype|description|
|---|---|---|
|x|pattern 1: (num_tokens, hidden), bfloat16; pattern 2: (num_tokens, hidden) torch.float8_e4m3fn plus (num_tokens, hidden // 128) torch.float|input tensor; pattern 1 is plain bf16, pattern 2 is x_e4m3 = per_token_cast_to_fp8(x)|
|num_tokens_per_rank|(num_ranks), int|number of tokens sent to each ep rank; note this is the count after topk deduplication|
|num_tokens_per_rdma_rank|(num_rdma_ranks), int|number of tokens sent to each rdma rank; set to None for intra-node-only communication|
|is_token_in_rank|(num_tokens, num_ranks), bool|whether each token is sent to a given ep rank; example for ep_rank0 with 4096 tokens: {-1, -1, 0, 1, -1, -1, -1, 2, -1, …} — token ids not belonging to this rank are set to -1, and tokens entering this rank are numbered incrementally from 0|
|num_tokens_per_expert|(num_experts), int|number of tokens sent to each expert|
|topk_idx|(num_tokens, topk), int64|each token's topk expert idx; -1 means not selected|
|topk_weights|(num_tokens, num_topk), torch.float|the weight with which each token is dispatched to each of its topk experts|
Python call interface:

```python
recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args)
```
intranode_dispatch interface:

```python
recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \
    self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights,
                                    num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None,
                                    expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream)
```
Create the compute and communication streams:

```cpp
auto compute_stream = at::cuda::getCurrentCUDAStream();
if (allocate_on_comm_stream) {
    at::cuda::setCurrentCUDAStream(comm_stream);
}
```
notify_dispatch:

```cpp
void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks,
                     const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts,
                     int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix,
                     int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment,
                     void** buffer_ptrs, int** task_fifo_ptrs, int head, int rank,
                     cudaStream_t stream, int num_channels) {
#define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \
    LAUNCH_KERNEL(&cfg, notify_dispatch<ranks>, \
                  num_tokens_per_rank, moe_recv_counter_mapped, \
                  num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \
                  num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \
                  rank_prefix_matrix_copy, num_memset_int, expert_alignment, \
                  buffer_ptrs, task_fifo_ptrs, head, rank); \
    break

    constexpr int kNumThreads = 128;
    EP_HOST_ASSERT(num_experts % num_ranks == 0);
    EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads);
    SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream);
    SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE);
#undef NOTIFY_DISPATCH_LAUNCH_CASE
}
```
|buffer|meaning|
|---|---|
|per_rank_buffer[rank][i, j]|number of tokens sent from rank_i to rank_j|
|per_expert_buffer[rank][i, j]|number of tokens sent from rank_i to expert_j|
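These counts are turned into receive offsets by prefix sums; a minimal sketch of one plausible layout (illustrative only; the actual bookkeeping lives in the notify_dispatch kernel):

```python
def prefix_offsets(per_rank_counts):
    """Exclusive prefix sum per column: offsets[i][j] = where rank_i's tokens start
    in rank_j's receive buffer, given counts[i][j] tokens sent i -> j."""
    num_ranks = len(per_rank_counts)
    offsets = [[0] * num_ranks for _ in range(num_ranks)]
    for j in range(num_ranks):
        acc = 0
        for i in range(num_ranks):
            offsets[i][j] = acc
            acc += per_rank_counts[i][j]
    return offsets

counts = [[3, 1], [2, 4]]  # counts[i][j]: tokens sent from rank_i to rank_j
print(prefix_offsets(counts))
# [[0, 0], [3, 1]]: rank1's tokens land after rank0's 3 (resp. 1) tokens
```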
internode_dispatch interface:

```python
# Internode
if self.runtime.get_num_rdma_ranks() > 1:
    return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert,
                                   topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream)
```
fused combine
intranode_combine interface:

```cpp
intranode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
                  const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix,
                  const torch::Tensor& send_head, const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
```
internode_combine interface:

```cpp
internode_combine(const torch::Tensor& x, const std::optional<torch::Tensor>& topk_weights,
                  const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank,
                  const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix,
                  const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head,
                  const Config& config, std::optional<EventHandle>& previous_event, bool async, bool allocate_on_comm_stream);
```
rdma send recv
csrc/kernels/internode.cu