[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)
### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into eagle_proposer.py
This pull request refactors the speculative decoding mechanism by merging the Parallel Context Processing (PCP) and Multi-Token Prediction (MTP) functionality directly into eagle_proposer.py. The goal is to improve the efficiency and correctness of distributed speculative decoding, in particular by letting the Eagle feature work with the disable_padded interface. This requires adjustments to attention metadata, input/output processing, and state management so that everything behaves correctly in parallel environments.
1. The PCP and MTP features are migrated into eagle_proposer.py.
2. The Eagle and PCP features are integrated.
3. The Eagle feature can now use the disable_padded interface (a usage sketch follows this list).
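As a hedged usage sketch (not part of this diff; the model and draft names are illustrative placeholders), the Eagle path that this series routes through eagle_proposer.py is typically exercised via vLLM's speculative_config:

```python
# Illustrative only: drives the Eagle3 speculative-decoding path that this
# PR merges into eagle_proposer.py. Model/draft names are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",            # example target model
    speculative_config={
        "method": "eagle3",                              # non-MTP method, so it hits the attn_state recovery below
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # example draft model
        "num_speculative_tokens": 2,
    },
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=32))[0].outputs[0].text)
```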
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests and unit tests.
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
@@ -561,7 +561,6 @@ class NPUModelRunner(GPUModelRunner):
             dtype=np.int32,
         )
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens)
-        self.attn_state = attn_state  # type: ignore

         # Determine if it's a splitfuse batch
         with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding]
@@ -800,7 +799,7 @@ class NPUModelRunner(GPUModelRunner):
             attn_state = AscendAttentionState.SpecDecoding
         # Speculative decoding.
         elif np.all(num_valid_tokens == 1):
-            if self.speculative_config and self.speculative_config.method == "mtp":
+            if self.speculative_config:
                 attn_state = AscendAttentionState.SpecDecoding
             else:
                 attn_state = AscendAttentionState.ChunkedPrefill
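Design note on the hunk above: previously an all-decode batch was tagged SpecDecoding only when the speculative method was MTP; after this change any configured speculative method qualifies, which is what allows Eagle drafts to take the speculative-decoding path at all.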
@@ -809,6 +808,14 @@ class NPUModelRunner(GPUModelRunner):
                 attn_state = AscendAttentionState.ChunkedPrefill
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
+
+        # For the overlay of the PCP feature and eagle3, attn_state needs to be recovered.
+        # TODO: Resolve the conflict between the sunsetting of attn_state and the PCP feature that requires this interface.
+        if attn_state == AscendAttentionState.SpecDecoding and self.speculative_config.method != "mtp":
+            self.attn_state = AscendAttentionState.ChunkedPrefill  # type: ignore
+        else:
+            self.attn_state = attn_state  # type: ignore
+
         return attn_state

     def _calc_spec_decode_metadata(
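A minimal sketch of the recovery rule just added (the helper below is hypothetical; the real change assigns self.attn_state inline inside _build_attn_state): non-MTP methods such as Eagle keep SpecDecoding for the returned value used to build attention metadata, while the stored runner-level state falls back to ChunkedPrefill so the PCP splitting logic keeps working.

```python
# Hypothetical standalone helper mirroring the inline recovery rule above.
# The import path is assumed; adjust to wherever AscendAttentionState lives.
from vllm_ascend.attention.attention_v1 import AscendAttentionState

def recovered_attn_state(attn_state: AscendAttentionState,
                         speculative_method: str) -> AscendAttentionState:
    # Eagle-style (non-MTP) speculative batches are downgraded to
    # ChunkedPrefill at the runner level so PCP can still split them.
    if (attn_state == AscendAttentionState.SpecDecoding
            and speculative_method != "mtp"):
        return AscendAttentionState.ChunkedPrefill
    return attn_state
```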
@@ -977,6 +984,10 @@ class NPUModelRunner(GPUModelRunner):
             target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
             target_positions = self._get_positions(num_scheduled_tokens)
             target_hidden_states = hidden_states
+            if self.use_aux_hidden_state_outputs:
+                target_hidden_states = torch.cat(
+                    [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
+                )
         else:
             token_indices_to_sample = None
             # input_ids can be None for multimodal models.
@@ -1014,6 +1025,8 @@ class NPUModelRunner(GPUModelRunner):
                 target_token_ids = input_ids_pcp_full[token_indices]
                 target_positions = positions
                 target_hidden_states = hidden_states
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_token_ids = self.input_ids.gpu[token_indices]
                 target_positions = self._get_positions(token_indices)
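Both hunks above add the same Eagle3 detail in their respective branches: auxiliary hidden states captured from several intermediate layers are sliced to the tokens of interest and concatenated along the feature dimension to form the draft model's input. A toy shape check with made-up sizes:

```python
# Toy illustration of the concatenation above: n tokens, hidden size H,
# three aux layers -> a draft input of width 3 * H. All sizes are made up.
import torch

n, H = 8, 16
aux_hidden_states = [torch.randn(n + 4, H) for _ in range(3)]  # padded layer outputs
target_hidden_states = torch.cat([h[:n] for h in aux_hidden_states], dim=-1)
assert target_hidden_states.shape == (n, 3 * H)
```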
@@ -1260,13 +1273,18 @@ class NPUModelRunner(GPUModelRunner):
                 num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs
             )
         with record_function_or_nullcontext("post process"):
+            aux_hidden_states = None
+            if self.use_aux_hidden_state_outputs:
+                hidden_states, aux_hidden_states = hidden_states
             if self.pcp_size > 1:
                 # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx
                 # ignores the padding from CUDA Graph.
                 hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states)
-            aux_hidden_states = None
-            if self.use_aux_hidden_state_outputs:
-                hidden_states, aux_hidden_states = hidden_states
+                if aux_hidden_states is not None:
+                    aux_hidden_states = [
+                        self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp)
+                        for aux_hidden_states_pcp in aux_hidden_states
+                    ]

             if not self.broadcast_pp_output:
                 # Common case.
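The NOTE in this hunk is the key constraint: pcp_allgather_restore_idx is built without the padding appended for graph capture, so the gathered tensor must be sliced before reordering, and the same restore is now applied to each aux hidden state. A minimal sketch of that restore step, assuming an index tensor with one entry per real token (the actual logic lives in the PCP manager's get_restore_hidden_states):

```python
# Minimal sketch; `restore_idx` and the shapes are illustrative assumptions,
# not the real PCPManager API.
import torch

def restore_hidden_states(hidden_states: torch.Tensor,
                          restore_idx: torch.Tensor) -> torch.Tensor:
    # Drop the graph-capture padding first, since restore_idx ignores it...
    hidden_states = hidden_states[: restore_idx.numel()]
    # ...then put the all-gathered tokens back into their original order.
    return hidden_states.index_select(0, restore_idx)
```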