Revert "[Refactor][EAGLE] 8/N delete mtp_proposer" (#7030)

Reverts vllm-project/vllm-ascend#7016 because it breaks the E2E tests.
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
This commit is contained in:
wangxiyuan
2026-03-06 11:24:05 +08:00
committed by GitHub
parent 8c2c82f3e1
commit 16c3b0b822
6 changed files with 931 additions and 19 deletions

View File

@@ -129,11 +129,7 @@ class AscendEagleProposer(EagleProposer):
         self.use_cuda_graph = self.runner._use_aclgraph() and not self.speculative_config.enforce_eager
-        if self.method == "mtp":
-            self.use_cuda_graph = (
-                self.use_cuda_graph
-                and not self.use_async_scheduling
-                and not self.speculative_config.disable_padded_drafter_batch
-            )
+        self.use_cuda_graph = self.use_cuda_graph and not self.use_async_scheduling
         # TODO: Remove it when the bug of fx-graph is solved
         self.maybe_eager_context: AbstractContextManager[Any] = nullcontext()
@@ -344,8 +340,7 @@ class AscendEagleProposer(EagleProposer):
         # Set the real slot_mapping.
         common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
         attn_metadata_eagle = builder.build_for_graph_capture(
-            common_attn_metadata,
-            AscendAttentionState.SpecDecoding if self.method == "mtp" else AscendAttentionState.ChunkedPrefill,
+            common_attn_metadata, AscendAttentionState.ChunkedPrefill
         )
         per_layer_attn_metadata = dict()
         for layer_name in self.attn_layer_names:
@@ -541,7 +536,7 @@ class AscendEagleProposer(EagleProposer):
         slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0]
         self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens])
         self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
-        common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
+        common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
         common_attn_metadata.num_input_tokens = num_input_tokens
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
@@ -905,9 +900,7 @@ class AscendEagleProposer(EagleProposer):
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
         common_attn_metadata.decode_token_per_req = 1
-        common_attn_metadata.attn_state = (
-            AscendAttentionState.SpecDecoding if self.method == "mtp" else AscendAttentionState.ChunkedPrefill
-        )
+        common_attn_metadata.attn_state = AscendAttentionState.ChunkedPrefill
         common_attn_metadata.graph_pad_size = -1
         common_attn_metadata.num_input_tokens = input_batch_size
@@ -989,7 +982,7 @@ class AscendEagleProposer(EagleProposer):
         self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32))
         self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID)
         # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx]
-        common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
+        common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]]
        # Rebuild attention metadata
        attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore