[Main] [Refactor] Enable MoECommMethod in Eager Mode (#2791)
### What this PR does / why we need it?
1. Replace prepare/finalize operation in fused_moe.py by
moe_comm_method.prepare()/finalize()
2. Replace unified_fused_experts by moe_comm_method.fused_experts() in
fused_moe.py/w8a8_dynamic.py/w4a8_dynamic.py
3. Add calling _select_moe_comm_method in spec-decode proposers.
4. Currently, w4a8_dynamic does not support gatherep, use all2allv
instead.
5. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
AllgatherEP switch is disabled in aclgraph/eager mode, just follow the
rules in modelrunner_v1._select_moe_comm_method()
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.2
- vLLM main:
7f6f2c1182
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -117,8 +117,11 @@ class EagleProposer(Proposer):
|
||||
skip_attn: bool = False,
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None):
|
||||
moe_comm_method = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
with set_ascend_forward_context(None,
|
||||
self.vllm_config,
|
||||
moe_comm_method=moe_comm_method,
|
||||
num_tokens=num_tokens):
|
||||
self.model(
|
||||
input_ids=self.input_ids[:num_tokens],
|
||||
@@ -447,12 +450,20 @@ class EagleProposer(Proposer):
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
with_prefill = attn_metadata.attn_state not in [
|
||||
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
|
||||
]
|
||||
moe_comm_method = self.runner._select_moe_comm_method(
|
||||
num_input_tokens, with_prefill)
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions.to(device)
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
attn_metadata.block_tables = block_table.to(device)
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
moe_comm_method=moe_comm_method,
|
||||
num_tokens=num_input_tokens):
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
@@ -483,6 +494,10 @@ class EagleProposer(Proposer):
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
else:
|
||||
input_batch_size = batch_size
|
||||
|
||||
moe_comm_method = self.runner._select_moe_comm_method(
|
||||
input_batch_size, False)
|
||||
|
||||
attn_metadata.num_actual_tokens = batch_size
|
||||
attn_metadata.max_query_len = 1
|
||||
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
|
||||
@@ -553,6 +568,7 @@ class EagleProposer(Proposer):
|
||||
# Run the model.
|
||||
with set_ascend_forward_context(attn_metadata,
|
||||
self.vllm_config,
|
||||
moe_comm_method=moe_comm_method,
|
||||
num_tokens=input_batch_size):
|
||||
|
||||
last_hidden_states, hidden_states = self.model(
|
||||
|
||||
@@ -112,6 +112,10 @@ class MtpProposer(Proposer):
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self.runner._sync_metadata_across_dp(num_tokens,
|
||||
with_prefill, False)
|
||||
|
||||
moe_comm_method = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
|
||||
is_running_torchair = self.torchair_graph_enabled and \
|
||||
not with_prefill
|
||||
|
||||
@@ -142,6 +146,7 @@ class MtpProposer(Proposer):
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=0):
|
||||
if is_running_torchair:
|
||||
@@ -411,6 +416,9 @@ class MtpProposer(Proposer):
|
||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
||||
with_prefill = self.runner.with_prefill
|
||||
|
||||
moe_comm_method = self.runner._select_moe_comm_method(
|
||||
num_input_tokens, with_prefill)
|
||||
|
||||
for step in range(self.num_speculative_tokens):
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
@@ -419,6 +427,7 @@ class MtpProposer(Proposer):
|
||||
with_prefill=with_prefill,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
reserved_mc2_mask=self.runner.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method,
|
||||
in_profile_run=self.runner.in_profile_run,
|
||||
num_actual_tokens=num_tokens):
|
||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||
|
||||
Reference in New Issue
Block a user