[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)
### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into eagle_proposer.py
This pull request substantially refactors the speculative decoding mechanism by merging the prefill context parallelism (PCP) and Multi-Token Prediction (MTP) functionality directly into eagle_proposer.py. The goal is to improve the efficiency and correctness of distributed speculative decoding, in particular by letting the Eagle feature work seamlessly with the disable_padded interface. This requires targeted changes to attention metadata construction, input/output processing, and state management so that everything behaves correctly in parallel environments.
1. The PCP and MTP features are migrated into eagle_proposer.py.
2. The Eagle and PCP features are integrated.
3. The Eagle feature can now use the disable_padded interface (see the sketch after this list).
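To make point 3 concrete, here is a minimal sketch of the token-counting consequence of disabling drafter padding. It is illustrative only: the numbers are invented, and reading `disable_padded` as the padded-drafter-batch toggle is our assumption.

```python
# With a padded drafter batch, every decode request occupies the same
# number of query slots, so request indices map linearly to token indices.
# With padding disabled, the number of accepted draft tokens varies per
# request, and attention metadata must be split at flattened token-count
# boundaries rather than request-count boundaries.
padded_query_lens = [3, 3, 3]    # 1 verified token + 2 draft slots each
unpadded_query_lens = [3, 1, 2]  # varies with how many drafts survived

num_decodes = 3                                 # decode *requests*
num_decodes_flatten = sum(unpadded_query_lens)  # decode *tokens*: 6, not 9
```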
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Verified with tests and unit tests (UT).
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Key hunks from this commit:
```diff
@@ -117,6 +117,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
 
         block_table = common_attn_metadata.block_table_tensor
         query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        self.num_decodes_flatten = query_lens[:num_decodes].sum().item()
         seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
 
         long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
```
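The single added line computes the flattened decode-token count from the per-request query lengths. A standalone sketch of the same arithmetic, with made-up tensor values:

```python
import torch

# query_start_loc marks where each request's query tokens begin, so
# adjacent differences give per-request query lengths; summing the first
# num_decodes entries yields the total number of decode tokens.
query_start_loc_cpu = torch.tensor([0, 3, 4, 6, 14])  # 3 decodes + 1 prefill
num_decodes = 3
query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]  # [3, 1, 2, 8]
num_decodes_flatten = query_lens[:num_decodes].sum().item()
print(num_decodes_flatten)  # 6
```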
```diff
@@ -146,7 +147,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
         pcp_size = get_pcp_group().world_size
         if self.chunked_prefill_enabled and max_context_len_cpu > 0:
             local_context_lens_allranks = (
-                torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs]
+                torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :]
                 .to(self.device)
                 .to(dtype=torch.int32)
            )
```
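In this chunked-prefill branch, the per-rank computed-token table now skips the flattened decode rows before being moved to device. A shape-level sketch; the `(rows, pcp_size, dcp_size)` layout is inferred from the `[:, self.pcp_rank, self.dcp_rank]` indexing in the last hunk below:

```python
import torch

pcp_size, dcp_size = 2, 2
rows = 7                 # flattened decode tokens + prefill rows
num_decodes_flatten = 4
num_computed_tokens_of_pcp_dcp = [[[0] * dcp_size] * pcp_size] * rows

# Prefill-side context lengths per (pcp, dcp) rank: skip the decode rows.
local_context_lens_allranks = (
    torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes_flatten:]
    .to(dtype=torch.int32)
)
print(local_context_lens_allranks.shape)  # torch.Size([3, 2, 2])
```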
```diff
@@ -214,23 +215,24 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder):
             prefill_metadata = AscendMetadataForPrefill(
                 pcp_metadata=pcp_metadata,
                 chunked_context=chunked_context_metadata,
-                block_tables=block_table[num_decodes:],
+                block_tables=block_table[self.num_decodes_flatten :, ...],
                 actual_seq_lengths_q=torch.cumsum(query_lens, dim=0),
             )
 
         if num_decodes > 0:
             num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp)
-            num_computed_tokens_array = num_computed_tokens_array[:num_decodes]
+            num_computed_tokens_array = num_computed_tokens_array[: self.num_decodes_flatten]
             # TODO: numpy array mode of the shared memory is used to improve performance
             decode_metadata = AscendMetadataForDecode(
                 num_computed_tokens_of_pcp_dcp=num_computed_tokens_array,
-                block_tables=block_table[:num_decodes],
+                block_tables=block_table[: self.num_decodes_flatten],
             )
 
         attn_metadata = AscendMetadata(
             num_actual_tokens=num_actual_tokens,
             num_decode_tokens=num_decode_tokens,
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
+            num_decodes_flatten=self.num_decodes_flatten,
             block_tables=block_table,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
```
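This hunk applies the same flattened split point to the block tables on both the prefill and decode sides, and threads `num_decodes_flatten` into `AscendMetadata` so downstream consumers can use it. A schematic sketch of the split; the one-row-per-token layout of `block_table` here is our assumption:

```python
import numpy as np

num_decodes = 2          # decode requests
num_decodes_flatten = 5  # decode tokens once per-request padding is gone
block_table = np.zeros((9, 4), dtype=np.int32)  # assumed: one row per token

decode_rows = block_table[:num_decodes_flatten]   # was block_table[:num_decodes]
prefill_rows = block_table[num_decodes_flatten:]  # was block_table[num_decodes:]
assert len(decode_rows) + len(prefill_rows) == len(block_table)
```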
```diff
@@ -550,7 +552,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
             "actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[
                 :, self.pcp_rank, self.dcp_rank
             ],
-            "actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes],
+            "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1,
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
```
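The replacement in this decode path leans on a small identity: if, as the new code suggests, each flattened decode token is treated as a unit-length query, the cumulative query lengths are simply 1, 2, ..., n, so `torch.arange(n) + 1` reproduces what the old cumsum-based slice computed. A self-contained check:

```python
import torch

n = 5  # any flattened decode-token count
per_token_query_lens = torch.ones(n, dtype=torch.int64)
# cumsum over a vector of ones is exactly 1, 2, ..., n
assert torch.equal(torch.cumsum(per_token_query_lens, dim=0),
                   torch.arange(n) + 1)
```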