[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)

### What this PR does / why we need it?
1. Replace the prepare/finalize operations in fused_moe.py with
moe_comm_method.prepare()/finalize().
2. Replace unified_fused_experts with moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py.
3. Add calls to _select_moe_comm_method in the spec-decode proposers.
4. Currently, w4a8_dynamic does not support AllgatherEP, so all2allv is
used instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
The AllgatherEP switch is disabled in aclgraph/eager mode; the MoE
communication method now simply follows the rules in
modelrunner_v1._select_moe_comm_method().
### How was this patch tested?
End-to-end (e2e) tests and unit tests (UT).


- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-16 11:06:00 +08:00
committed by GitHub
parent 0aba644633
commit 18ca7861f6
18 changed files with 523 additions and 596 deletions

View File

@@ -1663,7 +1663,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_connector_output=kv_connector_output,
)
def _select_moe_comm_method(self, num_tokens: int) -> str:
def _select_moe_comm_method(self, num_tokens: int,
with_prefill: bool) -> str:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -1687,6 +1688,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
"""
soc_version = get_ascend_soc_version()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)
if not self.parallel_config.enable_expert_parallel:
moe_comm_method = "allgather"
@@ -1694,12 +1697,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if num_tokens <= self.mc2_tokens_capacity and self.parallel_config.world_size_across_dp >= 16:
moe_comm_method = "mc2"
else:
moe_comm_method = "allgather"
if quant_type == "w4a8_dynamic":
moe_comm_method = "alltoall"
else:
moe_comm_method = "allgather"
elif soc_version in {AscendSocVersion.A3}:
moe_comm_method = "mc2" if num_tokens <= self.mc2_tokens_capacity else "alltoall"
else:
raise ValueError(f"Unsupported soc_version: {soc_version}")
if moe_comm_method == "allgather" and with_prefill:
moe_comm_method = "naivemulticast"
if is_global_first_rank():
logger.debug(f"num_tokens: {num_tokens}, "
f"moe_comm_method: {moe_comm_method}")
@@ -1728,7 +1738,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors) = (self._prepare_inputs(
scheduler_output, intermediate_tensors))
moe_comm_method = self._select_moe_comm_method(num_input_tokens)
moe_comm_method = self._select_moe_comm_method(num_input_tokens,
self.with_prefill)
batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
uniform_decode=False)
@@ -2100,7 +2111,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
moe_comm_method = self._select_moe_comm_method(num_tokens)
moe_comm_method = self._select_moe_comm_method(num_tokens,
with_prefill)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.seperate_routine(). This means that we are using