[Feature] Enhance all-reduce skipping logic for MoE models in NPUModelRunner (#5329)

Besides enabling `recompute_scheduler_enable`, we can skip all_reduce
when max_num_batched_tokens is below mc2's requirement.

- vLLM version: release/v0.13.0
- vLLM main:
bc0a5a0c08

---------

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
Jade Zheng
2025-12-26 17:39:44 +08:00
committed by GitHub
parent 06732dbf5b
commit 0dfdfa9526

View File

@@ -398,24 +398,41 @@ class NPUModelRunner(GPUModelRunner):
return self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE and self.compilation_config.mode == CompilationMode.VLLM_COMPILE and not self.model_config.enforce_eager
def _skip_all_reduce_acorss_dp_group(self) -> bool:
# NOTE: We can skip the all_reduce operation and avoid paading tokens
# to max_tokens_acrodd_dp in D nodes. In MoE models, we must ensure that
# num_tokens DOES NOT exceed mc2_tokens_capacity which means that moe_comm_method
# of each rank is MC2. For dense models, skipping all_reduce is not necessary
# since collective-communication is not time-consuming since dp_size in dense
# model deployments is always small and can be overlapped by async scheduling.
if not is_moe_model(self.vllm_config):
"""
Decide whether to skip the all-reduce across the data-parallel (DP) group.
Skipping is only applicable for MoE models and only on ranks that act as
KV consumers. We skip the DP all-reduce when either:
- Both the prefill and decode communication methods are MC2 (or FUSED_MC2), or
- Decode requires MC2 and ascend_config.recompute_scheduler_enable is True.
"""
# Only applicable to MoE models and KV consumer ranks.
if not is_moe_model(self.vllm_config) or not self.is_kv_consumer:
return False
def needs_mc2(num_tokens: int) -> bool:
return select_moe_comm_method(num_tokens, self.vllm_config) in {
MoECommType.MC2, MoECommType.FUSED_MC2
}
# Determine whether decode must use MC2. Use max cudagraph capture size
# if available, otherwise use the maximal uniform decode token count.
if self.compilation_config.cudagraph_capture_sizes:
potential_max_num_tokens = self.compilation_config.max_cudagraph_capture_size
potential_max_tokens = self.compilation_config.max_cudagraph_capture_size
else:
potential_max_num_tokens = self.max_num_reqs * self.uniform_decode_query_len
# To ensure skipping all_reduce across dp group is valid, we need to ensure that
# moe_comm_method of each rank is MC2 and recomputation would never happen in D
# nodes. So here we check whether recompute_scheduler_enable is True.
return self.is_kv_consumer and self.ascend_config.recompute_scheduler_enable and select_moe_comm_method(
potential_max_num_tokens,
self.vllm_config) in {MoECommType.MC2, MoECommType.FUSED_MC2}
potential_max_tokens = self.max_num_reqs * self.uniform_decode_query_len
decode_must_use_mc2 = needs_mc2(potential_max_tokens)
# For prefill, use the scheduler's max_num_batched_tokens for a single
# batch.
prefill_must_use_mc2 = needs_mc2(
self.vllm_config.scheduler_config.max_num_batched_tokens)
# Skip all-reduce if decode requires MC2 and either prefill also
# requires MC2 or recompute-based scheduler is enabled.
return decode_must_use_mc2 and (
prefill_must_use_mc2
or self.ascend_config.recompute_scheduler_enable)
def _sync_metadata_across_dp(
self, num_tokens: int,