[Feat] Support separate attention backends for target and draft models. (#7342)

### What this PR does / why we need it?
This PR enables separate attention backend configuration for the target and
draft models in speculative decoding, decoupling attention backend settings
that were previously bound together across the two models.

It resolves a compatibility issue where some draft models do not support the
attention backend used by the target model, and it lets users select the
optimal attention backend for each model individually to maximize inference
performance. The change is fully backward compatible.
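
For illustration, a minimal usage sketch of what the decoupling enables. The `attn_backend` key inside `speculative_config` and the backend name below are assumptions for this example only, not the exact interface this PR adds:

```python
# Hypothetical sketch, assuming a per-draft-model backend knob. Only the
# decoupling itself is what this PR implements; the "attn_backend" key and
# the backend name are illustrative placeholders.
from vllm import LLM

llm = LLM(
    model="path/to/target-model",  # the target model keeps its own backend
    speculative_config={
        "method": "eagle",
        "model": "path/to/eagle-draft-model",
        "num_speculative_tokens": 3,
        # Assumed field: pick a backend the draft model supports, even if it
        # differs from the backend used by the target model.
        "attn_backend": "ASCEND",
    },
)
```
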
---------
Signed-off-by: SidaoY <1024863041@qq.com>

Author: HongtaoYang
Committed by: GitHub
Date: 2026-03-21 10:48:01 +08:00
Parent: 88d03a783f
Commit: 80a4265717

3 changed files with 177 additions and 49 deletions

```diff
@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
 from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
 
 # Currently we will fix block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -217,6 +217,8 @@ class SpecDecodeBaseProposer(EagleProposer):
             self.model.config.image_token_index = model.config.image_token_id
         elif self.get_model_name(model) == "PixtralForConditionalGeneration":
             self.model.config.image_token_index = model.config.vision_config.image_token_id
+        elif self.get_model_name(model) == "KimiK25ForConditionalGeneration":
+            self.model.config.image_token_index = model.config.media_placeholder_token_id
         else:
             self.model.config.image_token_index = model.config.image_token_index
         target_language_model = model.get_language_model()
```
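
The `elif` chain above maps each multimodal architecture to the config attribute that holds its image placeholder token. For readability, a table-driven equivalent of the branches visible in this hunk (the commit itself keeps the explicit chain; the helper below is hypothetical):

```python
# Illustrative rewrite of the visible branches. Attribute names come from
# the diff above; the dict and helper are hypothetical, not part of the PR.
_IMAGE_TOKEN_GETTERS = {
    "PixtralForConditionalGeneration":
        lambda cfg: cfg.vision_config.image_token_id,
    "KimiK25ForConditionalGeneration":
        lambda cfg: cfg.media_placeholder_token_id,
}

def resolve_image_token_index(model_name: str, cfg):
    # The default mirrors the final `else` branch: reuse cfg.image_token_index.
    getter = _IMAGE_TOKEN_GETTERS.get(model_name, lambda c: c.image_token_index)
    return getter(cfg)
```
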
```diff
@@ -388,7 +390,11 @@ class SpecDecodeBaseProposer(EagleProposer):
                 : num_reqs * self.decode_threshold
             ]
 
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         # update the tensor's address for each step.
         for draft_step in range(self.num_speculative_tokens):
             common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
@@ -550,7 +556,11 @@ class SpecDecodeBaseProposer(EagleProposer):
             common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
             common_attn_metadata.num_input_tokens = num_input_tokens
 
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
         if self.uses_mrope:
```
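
Both builder-selection hunks apply the same version-gated rule. A self-contained paraphrase, assuming the attribute layout shown in the diff (`draft_attn_groups` on the proposer, `attn_groups` on the runner):

```python
from vllm_ascend.utils import vllm_version_is

def select_metadata_builder(proposer):
    """Paraphrase of the selection logic in the two hunks above."""
    if vllm_version_is("0.17.0"):
        # New path: the draft model owns its own attention groups, so its
        # metadata builder can differ from the target model's.
        assert len(proposer.draft_attn_groups) > 0
        return proposer.draft_attn_groups[0].get_metadata_builder()
    # Older path: target and draft share the runner's attention groups, so
    # the draft model reuses the target model's metadata builder.
    return proposer.runner.attn_groups[0][0].get_metadata_builder()
```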