[Refactor][EAGLE] 8/N delete mtp_proposer (#7016)
### What this PR does / why we need it?
This PR deletes mtp_proposer. After fixing a bug in both dsv32 and
glm5, it is now safe to remove mtp_proposer. The bug was an
unnecessary slicing of `slot_mapping`.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
By CI.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -129,7 +129,11 @@ class AscendEagleProposer(EagleProposer):
|
||||
|
||||
self.use_cuda_graph = self.runner._use_aclgraph() and not self.speculative_config.enforce_eager
|
||||
if self.method == "mtp":
|
||||
self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling
|
||||
self.use_cuda_graph = (
|
||||
self.use_cuda_graph
|
||||
and not self.use_async_scheduling
|
||||
and not self.speculative_config.disable_padded_drafter_batch
|
||||
)
|
||||
|
||||
# TODO: Remove it when the bug of fx-graph is solved
|
||||
self.maybe_eager_context: AbstractContextManager[Any] = nullcontext()
|
||||
@@ -340,7 +344,8 @@ class AscendEagleProposer(EagleProposer):
|
||||
# Set the real slot_mapping.
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
|
||||
attn_metadata_eagle = builder.build_for_graph_capture(
|
||||
common_attn_metadata, AscendAttentionState.ChunkedPrefill
|
||||
common_attn_metadata,
|
||||
AscendAttentionState.SpecDecoding if self.method == "mtp" else AscendAttentionState.ChunkedPrefill,
|
||||
)
|
||||
per_layer_attn_metadata = dict()
|
||||
for layer_name in self.attn_layer_names:
|
||||
@@ -536,7 +541,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0]
|
||||
self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens])
|
||||
self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
|
||||
common_attn_metadata.num_input_tokens = num_input_tokens
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
@@ -900,7 +905,9 @@ class AscendEagleProposer(EagleProposer):
|
||||
common_attn_metadata.num_actual_tokens = batch_size
|
||||
common_attn_metadata.max_query_len = 1
|
||||
common_attn_metadata.decode_token_per_req = 1
|
||||
common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
|
||||
common_attn_metadata.attn_state = (
|
||||
AscendAttentionState.SpecDecoding if self.method == "mtp" else AscendAttentionState.ChunkedPrefill
|
||||
)
|
||||
common_attn_metadata.graph_pad_size = -1
|
||||
common_attn_metadata.num_input_tokens = input_batch_size
|
||||
|
||||
@@ -982,7 +989,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
|
||||
self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
|
||||
# Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
|
||||
Reference in New Issue
Block a user