[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)

### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut


- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-09-16 11:06:00 +08:00
committed by GitHub
parent 0aba644633
commit 18ca7861f6
18 changed files with 523 additions and 596 deletions

View File

@@ -117,8 +117,11 @@ class EagleProposer(Proposer):
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None):
moe_comm_method = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
with set_ascend_forward_context(None,
self.vllm_config,
moe_comm_method=moe_comm_method,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
@@ -447,12 +450,20 @@ class EagleProposer(Proposer):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
with_prefill = attn_metadata.attn_state not in [
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
moe_comm_method = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions.to(device)
self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device)
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_method=moe_comm_method,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
@@ -483,6 +494,10 @@ class EagleProposer(Proposer):
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
moe_comm_method = self.runner._select_moe_comm_method(
input_batch_size, False)
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
@@ -553,6 +568,7 @@ class EagleProposer(Proposer):
# Run the model.
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
moe_comm_method=moe_comm_method,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(

View File

@@ -112,6 +112,10 @@ class MtpProposer(Proposer):
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._sync_metadata_across_dp(num_tokens,
with_prefill, False)
moe_comm_method = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
is_running_torchair = self.torchair_graph_enabled and \
not with_prefill
@@ -142,6 +146,7 @@ class MtpProposer(Proposer):
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
@@ -411,6 +416,9 @@ class MtpProposer(Proposer):
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
moe_comm_method = self.runner._select_moe_comm_method(
num_input_tokens, with_prefill)
for step in range(self.num_speculative_tokens):
with set_ascend_forward_context(
attn_metadata,
@@ -419,6 +427,7 @@ class MtpProposer(Proposer):
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_method=moe_comm_method,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):