[Spec-Decode] Fix spec decode proposer in 0.18.0 (#7544)

### What this PR does / why we need it?
As the vllm-ascend main doesn't maintain v0.17.0 now, we'd just apply
the single branch in eagle proposer. Otherwise it will raise error in
v0.18.0

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI passed with existing test.

- vLLM version: v0.18.0
- vLLM main:
8b6325758c

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2026-03-23 15:39:24 +08:00
committed by GitHub
parent 6b7d9b76f1
commit 9e2878065a

View File

@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
# Currently we will fix block size to a small one since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -390,11 +390,8 @@ class SpecDecodeBaseProposer(EagleProposer):
: num_reqs * self.decode_threshold
]
if vllm_version_is("0.17.0"):
assert len(self.draft_attn_groups) > 0
builder = self.draft_attn_groups[0].get_metadata_builder()
else:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
assert len(self.draft_attn_groups) > 0
builder = self.draft_attn_groups[0].get_metadata_builder()
# update the tensor's address for each step.
for draft_step in range(self.num_speculative_tokens):
common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
@@ -558,11 +555,8 @@ class SpecDecodeBaseProposer(EagleProposer):
common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
common_attn_metadata.num_input_tokens = num_input_tokens
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
if vllm_version_is("0.17.0"):
assert len(self.draft_attn_groups) > 0
builder = self.draft_attn_groups[0].get_metadata_builder()
else:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
assert len(self.draft_attn_groups) > 0
builder = self.draft_attn_groups[0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
if self.uses_mrope:
@@ -1549,11 +1543,8 @@ class SpecDecodeBaseProposer(EagleProposer):
# update full-graph params for one spec token
def _update_full_graph_params(self, forward_context, num_tokens, draft_attn_metadatas=None):
if vllm_version_is("0.17.0"):
assert len(self.draft_attn_groups) > 0
attn_backend = self.draft_attn_groups[0].backend
else:
attn_backend = self.runner.attn_backend
assert len(self.draft_attn_groups) > 0
attn_backend = self.draft_attn_groups[0].backend
update_full_graph_params(
attn_backend,
self.update_stream,