[Refactor][EAGLE] 6/N route mtp to eagle except pcp/dcp+mtp (#6349)
### What this PR does / why we need it?
Overview: This pull request refactors speculative decoding for Eagle and
MTP proposers on Ascend hardware. It fixes a bug related to
draft_attn_metadatas being lost, migrates the lmhead feature, and adds
routing logic in MtpProposer.
Details:
1. Migrated the lmhead feature from mtp to eagle and normalized it in
eagle_proposer.
2. Fixed the bug where draft_attn_metadatas was lost after enabling
eagle mode in the merge graph.
3. Added routing logic for PCP and for the disable-padded-drafter-batch
option: in MTP mode, when PCP is not in use and
disable_padded_drafter_batch is not enabled, the normalized
eagle_proposer implementation is used instead of the MTP-specific path.
RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Verified with unit tests and end-to-end testing.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
@@ -37,7 +37,16 @@ class MtpProposer(EagleProposer):
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None,
|
||||
is_profile=False) -> None:
|
||||
|
||||
if (
|
||||
self.pcp_size * self.dcp_size == 1
|
||||
and not self.speculative_config.disable_padded_drafter_batch
|
||||
):
|
||||
super().dummy_run(
|
||||
num_tokens, with_prefill, in_graph_capturing, num_reqs,
|
||||
num_tokens_across_dp, aclgraph_runtime_mode, batch_descriptor,
|
||||
dummy_compute_logits, is_profile
|
||||
)
|
||||
return
|
||||
(
|
||||
num_tokens,
|
||||
num_tokens_across_dp,
|
||||
@@ -151,6 +160,19 @@ class MtpProposer(EagleProposer):
|
||||
scheduler_output: SchedulerOutput = None,
|
||||
num_scheduled_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
self.pcp_size * self.dcp_size == 1
|
||||
and not self.speculative_config.disable_padded_drafter_batch
|
||||
):
|
||||
draft_token_ids = super()._propose(
|
||||
target_token_ids, target_positions, target_hidden_states,
|
||||
next_token_ids, last_token_indices, common_attn_metadata,
|
||||
sampling_metadata, mm_embed_inputs, req_scheduled_tokens,
|
||||
long_seq_metadata, num_prefill_reqs, num_decode_reqs,
|
||||
scheduler_output, num_scheduled_tokens
|
||||
)
|
||||
return draft_token_ids
|
||||
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user