[perf] replace all_reduce for kv_consumer and support different num_tokens among all ranks (#4983)

pick from https://github.com/vllm-project/vllm-ascend/pull/4736 to fix
the merge conflict

### What this PR does / why we need it?
Currently, the all_reduce operation in _sync_metadata_across_dp is
performed with gloo backend which is extremely time-consuming when
DPEngineCores are in different nodes. This operation cannot be ignored
by async scheduling in multi-node-scenarios with speculative decoding
(e.g., EAGLE, mtp).

This pr eliminates the all_reduce operation for D Nodes and change the
input parameter of MoEDispatch & MoeCombine operators to make MC2EP
support different num_tokens across all ranks.

### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested with PD disaggregation (2P: DP2TP8EP16 1D: DP8TP4EP32) scenarios
while enabling async scheduling. This pr can remove cross-node
all_reduce with gloo backend and further reduce latency with correct
accuracy.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
wangxiyuan
2025-12-13 18:59:54 +08:00
committed by GitHub
parent 5211e991ad
commit fd7c929145
7 changed files with 69 additions and 26 deletions

View File

@@ -371,8 +371,12 @@ class AscendW4A8DynamicFusedMoEMethod:
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(
topk_ids, 0, global_num_experts - global_redundant_expert_num)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
topk_weights = topk_weights.to(x.dtype)

View File

@@ -215,8 +215,12 @@ class AscendW8A8DynamicFusedMoEMethod:
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(
topk_ids, 0, global_num_experts - global_redundant_expert_num)
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts -
global_redundant_expert_num,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
topk_weights = topk_weights.to(self.in_dtype)