[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)

### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into
eagle_proposer.py

This pull request significantly refactors the speculative decoding
mechanism by merging Parallel Context Processing (PCP) and Multi-Token
Prediction (MTP) functionalities directly into the eagle_proposer.py.
The changes aim to enhance the efficiency and correctness of distributed
speculative decoding, particularly by enabling the Eagle feature to work
seamlessly with the disable_padded interface. This involves detailed
adjustments to attention metadata, input/output processing, and state
management to ensure proper operation in parallel environments.

1. The PCP and MTP features are migrated to the eagle_proposer.py
2. The Eagle and PCP features are integrated
3. Enable the eagle feature to use the disable_padded interface

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests and UT

- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-02-27 16:06:56 +08:00
committed by GitHub
parent e4458b2d2b
commit c13d90b766
6 changed files with 245 additions and 60 deletions

View File

@@ -159,6 +159,7 @@ class AscendMetadata:
num_decode_tokens: int = 0
num_prefills: int = 0
num_decodes: int = 0
num_decodes_flatten: int = 0
# The sequence length per sequence. Sequence length means the computed
# tokens + new tokens (is None if it is a decoding).

View File

@@ -117,6 +117,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
block_table = common_attn_metadata.block_table_tensor
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
self.num_decodes_flatten = query_lens[:num_decodes].sum().item()
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
@@ -146,7 +147,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
pcp_size = get_pcp_group().world_size
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
local_context_lens_allranks = (
torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs]
torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
.to(self.device)
.to(dtype=torch.int32)
)
@@ -214,23 +215,24 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
prefill_metadata = AscendMetadataForPrefill(
pcp_metadata=pcp_metadata,
chunked_context=chunked_context_metadata,
block_tables=block_table[num_decodes:],
block_tables=block_table[self.num_decodes_flatten :, ...],
actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
)
if num_decodes > 0:
num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp)
num_computed_tokens_array = num_computed_tokens_array[:num_decodes]
num_computed_tokens_array = num_computed_tokens_array[: self.num_decodes_flatten]
# TODO: numpy array mode of the shared memory is used to improve performance
decode_metadata = AscendMetadataForDecode(
num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
block_tables=block_table[:num_decodes],
block_tables=block_table[: self.num_decodes_flatten],
)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
num_decode_tokens=num_decode_tokens,
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
num_decodes_flatten=self.num_decodes_flatten,
block_tables=block_table,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
@@ -550,7 +552,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
"actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
:, self.pcp_rank, self.dcp_rank
],
"actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes],
"actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()