[Refactor][EAGLE] 6/N route mtp to eagle except pcp/dcp+mtp (#6349)

### What this PR does / why we need it?

Overview: This pull request refactors speculative decoding for the Eagle and
MTP proposers on Ascend hardware. It fixes a bug where
`draft_attn_metadatas` was lost, migrates the lmhead feature, and adds
routing logic to `MtpProposer`.

Details:
1. Migrated the lmhead feature from the MTP proposer to the Eagle proposer
and normalized it in `eagle_proposer`.
2. Fixed the bug where `draft_attn_metadatas` was lost in the merge graph
after enabling Eagle mode.
3. Added routing for pcp/dcp and `disable_padded_drafter_batch`: in MTP
mode, when neither pcp/dcp parallelism nor `disable_padded_drafter_batch`
is enabled, the normalized `eagle_proposer` path is used.
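The routing in item 3 can be sketched as follows. This is a simplified, hypothetical stand-in: the class names and the `pcp_size` / `dcp_size` / `disable_padded_drafter_batch` attributes mirror the diff below, but the constructor, return values, and method body here are illustrative placeholders, not the actual vllm-ascend implementation.

```python
# Hypothetical sketch of the MtpProposer routing guard; not actual
# vllm-ascend code. Attribute names follow the diff in this PR.

class EagleProposer:
    def propose(self) -> str:
        return "eagle"  # normalized Eagle path


class MtpProposer(EagleProposer):
    def __init__(self, pcp_size: int, dcp_size: int,
                 disable_padded_drafter_batch: bool) -> None:
        self.pcp_size = pcp_size
        self.dcp_size = dcp_size
        self.disable_padded_drafter_batch = disable_padded_drafter_batch

    def propose(self) -> str:
        # Route to the normalized Eagle implementation unless pcp/dcp
        # parallelism is active or padded drafter batches are disabled.
        if (self.pcp_size * self.dcp_size == 1
                and not self.disable_padded_drafter_batch):
            return super().propose()
        return "mtp"  # MTP-specific fallback path
```

With this shape, only the pcp/dcp+MTP and disable-padded-drafter-batch combinations keep the MTP-specific path; everything else delegates to the shared Eagle code, which matches the PR title.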

RFC: https://github.com/vllm-project/vllm-ascend/issues/5467

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit tests and the existing test suite.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-02-02 19:15:31 +08:00
committed by GitHub
parent c08364f761
commit 7932255c06
4 changed files with 90 additions and 24 deletions


@@ -37,7 +37,16 @@ class MtpProposer(EagleProposer):
                  batch_descriptor=None,
                  dummy_compute_logits=lambda hidden_states: None,
                  is_profile=False) -> None:
        if (
            self.pcp_size * self.dcp_size == 1
            and not self.speculative_config.disable_padded_drafter_batch
        ):
            super().dummy_run(
                num_tokens, with_prefill, in_graph_capturing, num_reqs,
                num_tokens_across_dp, aclgraph_runtime_mode, batch_descriptor,
                dummy_compute_logits, is_profile
            )
            return
        (
            num_tokens,
            num_tokens_across_dp,
@@ -151,6 +160,19 @@ class MtpProposer(EagleProposer):
        scheduler_output: SchedulerOutput = None,
        num_scheduled_tokens: int = 0,
    ) -> torch.Tensor:
        if (
            self.pcp_size * self.dcp_size == 1
            and not self.speculative_config.disable_padded_drafter_batch
        ):
            draft_token_ids = super()._propose(
                target_token_ids, target_positions, target_hidden_states,
                next_token_ids, last_token_indices, common_attn_metadata,
                sampling_metadata, mm_embed_inputs, req_scheduled_tokens,
                long_seq_metadata, num_prefill_reqs, num_decode_reqs,
                scheduler_output, num_scheduled_tokens
            )
            return draft_token_ids
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]