[perf] replace all_reduce for kv_consumer and support different num_tokens among all ranks (#4983)
pick from https://github.com/vllm-project/vllm-ascend/pull/4736 to fix the merge conflict ### What this PR does / why we need it? Currently, the all_reduce operation in _sync_metadata_across_dp is performed with gloo backend which is extremely time-consuming when DPEngineCores are in different nodes. This operation cannot be ignored by async scheduling in multi-node-scenarios with speculative decoding (e.g., EAGLE, mtp). This pr eliminates the all_reduce operation for D Nodes and change the input parameter of MoEDispatch & MoeCombine operators to make MC2EP support different num_tokens across all ranks. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tested with PD disaggregation (2P: DP2TP8EP16 1D: DP8TP4EP32) scenarios while enabling async scheduling. This pr can remove cross-node all_reduce with gloo backend and further reduce latency with correct accuracy. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -371,8 +371,12 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(
|
||||
topk_ids, 0, global_num_experts - global_redundant_expert_num)
|
||||
random_matrix = torch.rand(topk_ids.size(0),
|
||||
global_num_experts -
|
||||
global_redundant_expert_num,
|
||||
device=topk_ids.device)
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
|
||||
@@ -215,8 +215,12 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(
|
||||
topk_ids, 0, global_num_experts - global_redundant_expert_num)
|
||||
random_matrix = torch.rand(topk_ids.size(0),
|
||||
global_num_experts -
|
||||
global_redundant_expert_num,
|
||||
device=topk_ids.device)
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
topk_weights = topk_weights.to(self.in_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user