[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)

### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into
eagle_proposer.py

This pull request significantly refactors the speculative decoding
mechanism by merging Parallel Context Processing (PCP) and Multi-Token
Prediction (MTP) functionalities directly into the eagle_proposer.py.
The changes aim to enhance the efficiency and correctness of distributed
speculative decoding, particularly by enabling the Eagle feature to work
seamlessly with the disable_padded interface. This involves detailed
adjustments to attention metadata, input/output processing, and state
management to ensure proper operation in parallel environments.

1. The PCP and MTP features are migrated to the eagle_proposer.py
2. The Eagle and PCP features are integrated
3. Enable the eagle feature to use the disable_padded interface

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests and UT

- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-02-27 16:06:56 +08:00
committed by GitHub
parent e4458b2d2b
commit c13d90b766
6 changed files with 245 additions and 60 deletions

View File

@@ -561,7 +561,6 @@ class NPUModelRunner(GPUModelRunner):
dtype=np.int32,
)
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
self.attn_state = attn_state # type: ignore
# Determine if it's a splitfuse batch
with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
@@ -800,7 +799,7 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.speculative_config and self.speculative_config.method == "mtp":
if self.speculative_config:
attn_state = AscendAttentionState.SpecDecoding
else:
attn_state = AscendAttentionState.ChunkedPrefill
@@ -809,6 +808,14 @@ class NPUModelRunner(GPUModelRunner):
attn_state = AscendAttentionState.ChunkedPrefill
else:
attn_state = AscendAttentionState.PrefillCacheHit
# For the overlay of the PCP feature and the eagle3, attn_state needs to be recovered
# TODO: Resolved the conflict between the sunset of attn_state and the PCP that requires this interface.
if attn_state == AscendAttentionState.SpecDecoding and self.speculative_config.method != "mtp":
self.attn_state = AscendAttentionState.ChunkedPrefill # type: ignore
else:
self.attn_state = attn_state # type: ignore
return attn_state
def _calc_spec_decode_metadata(
@@ -977,6 +984,10 @@ class NPUModelRunner(GPUModelRunner):
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = self._get_positions(num_scheduled_tokens)
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
)
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
@@ -1014,6 +1025,8 @@ class NPUModelRunner(GPUModelRunner):
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_token_ids = self.input_ids.gpu[token_indices]
target_positions = self._get_positions(token_indices)
@@ -1260,13 +1273,18 @@ class NPUModelRunner(GPUModelRunner):
num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs
)
with record_function_or_nullcontext("post process"):
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
if self.pcp_size > 1:
# NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
# ignores the padding from CUDA Graph.
hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states)
aux_hidden_states = None
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = hidden_states
if aux_hidden_states is not None:
aux_hidden_states = [
self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp)
for aux_hidden_states_pcp in aux_hidden_states
]
if not self.broadcast_pp_output:
# Common case.