[perf] replace all_reduce for kv_consumer and support different num_tokens among all ranks (#4983)

pick from https://github.com/vllm-project/vllm-ascend/pull/4736 to fix
the merge conflict

### What this PR does / why we need it?
Currently, the all_reduce operation in _sync_metadata_across_dp is
performed with gloo backend which is extremely time-consuming when
DPEngineCores are in different nodes. This operation cannot be ignored
by async scheduling in multi-node-scenarios with speculative decoding
(e.g., EAGLE, mtp).

This pr eliminates the all_reduce operation for D Nodes and change the
input parameter of MoEDispatch & MoeCombine operators to make MC2EP
support different num_tokens across all ranks.

### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tested with PD disaggregation (2P: DP2TP8EP16 1D: DP8TP4EP32) scenarios
while enabling async scheduling. This pr can remove cross-node
all_reduce with gloo backend and further reduce latency with correct
accuracy.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
wangxiyuan
2025-12-13 18:59:54 +08:00
committed by GitHub
parent 5211e991ad
commit fd7c929145
7 changed files with 69 additions and 26 deletions

View File

@@ -428,6 +428,25 @@ class NPUModelRunner(GPUModelRunner):
def _use_aclgraph(self) -> bool:
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
def _skip_all_reduce_acorss_dp_group(self) -> bool:
# NOTE: We can skip the all_reduce operation and avoid paading tokens
# to max_tokens_acrodd_dp in D nodes. In MoE models, we must ensure that
# num_tokens DOES NOT exceed mc2_tokens_capacity which means that moe_comm_method
# of each rank is MC2. For dense models, skipping all_reduce is not necessary
# since collective-communication is not time-consuming since dp_size in dense
# model deployments is always small and can be overlapped by async scheduling.
if not is_moe_model(self.vllm_config):
return False
if self.compilation_config.cudagraph_capture_sizes:
potential_max_num_tokens = self.compilation_config.max_cudagraph_capture_size
else:
potential_max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
# To ensure skipping all_reduce across dp group is valid, we need to ensure that
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
# nodes. So here we check whether recompute_scheduler_enable is True.
return self.is_kv_consumer and not self.in_profile_run and self.ascend_config.recompute_scheduler_enable and self._select_moe_comm_method(
potential_max_num_tokens) == MoECommType.MC2
def _sync_metadata_across_dp(
self, num_tokens: int,
with_prefill: bool) -> tuple[int, Optional[torch.Tensor], bool]:
@@ -439,6 +458,14 @@ class NPUModelRunner(GPUModelRunner):
# immediately once the other two flags are no longer needed.
if self.dp_size == 1:
return num_tokens, None, with_prefill
if self._skip_all_reduce_acorss_dp_group():
num_tokens_after_padding = torch.tensor([num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return num_tokens, num_tokens_after_padding, with_prefill
# Sync num_tokens, with_prefill across dp ranks
num_tokens_tensor = torch.tensor([
num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
@@ -1097,7 +1124,7 @@ class NPUModelRunner(GPUModelRunner):
attn_metadata[layer_name] = attn_metadata_i
if lmhead_tp_enable():
max_num_reqs_across_dp = maybe_padded_num_tokens if not with_prefill else self.max_num_reqs
max_num_reqs_across_dp = self.max_num_reqs * self.uniform_decode_query_len
logits_indices = nn.functional.pad(
logits_indices,
(0, max_num_reqs_across_dp - logits_indices.shape[0]))
@@ -2190,7 +2217,7 @@ class NPUModelRunner(GPUModelRunner):
need_dummy_logits = (not self.in_profile_run
and lmhead_tp_enable())
max_num_reqs_across_dp = num_tokens_padded if not with_prefill else max_num_reqs
max_num_reqs_across_dp = max_num_reqs * self.uniform_decode_query_len
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)