[Feat] Support separate attention backends for the target and draft models. (#7342)
### What this PR does / why we need it?

This PR enables separate attention backend configuration for the target and draft models in speculative decoding, decoupling the previously bound attention backend settings between the two models. It solves the compatibility issue where some draft models do not support the attention backend used by the target model, and it allows users to select the optimal attention backend for each model individually to maximize inference performance. The change is fully backward compatible.

---------

Signed-off-by: SidaoY <1024863041@qq.com>
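For context, a hypothetical usage sketch of the scenario this PR targets. Only `speculative_config` and its `model`/`num_speculative_tokens` keys are standard vLLM; the backend-selection behavior described in the comments is what this PR enables, and no new option name is shown because this sketch does not assume one.

```python
# Hypothetical sketch of the scenario this PR targets; not the PR's code.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # target model
    speculative_config={
        "model": "Qwen/Qwen2.5-0.5B-Instruct",  # draft model
        "num_speculative_tokens": 3,
        # With this PR, the draft model no longer inherits the target's
        # attention backend, so a draft that lacks support for the
        # target's backend can run on one it does support.
    },
)
out = llm.generate("Speculative decoding lets a draft model propose tokens")
print(out[0].outputs[0].text)
```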
```diff
@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
 from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
 
 # Currently we will fix block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
```
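The newly imported `vllm_version_is` gates behavior on the installed vLLM release. A minimal sketch of what such a helper typically looks like, assuming an exact-match check against the installed package version (the real implementation in `vllm_ascend.utils` may differ):

```python
# Minimal sketch of a version gate like vllm_ascend.utils.vllm_version_is.
# Assumption: the helper compares the installed vLLM release string for
# exact equality; the real implementation may differ.
from importlib.metadata import PackageNotFoundError, version


def vllm_version_is(target: str) -> bool:
    """Return True if the installed vLLM release matches `target` exactly."""
    try:
        return version("vllm") == target
    except PackageNotFoundError:
        return False
```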
```diff
@@ -217,6 +217,8 @@ class SpecDecodeBaseProposer(EagleProposer):
             self.model.config.image_token_index = model.config.image_token_id
         elif self.get_model_name(model) == "PixtralForConditionalGeneration":
             self.model.config.image_token_index = model.config.vision_config.image_token_id
+        elif self.get_model_name(model) == "KimiK25ForConditionalGeneration":
+            self.model.config.image_token_index = model.config.media_placeholder_token_id
         else:
             self.model.config.image_token_index = model.config.image_token_index
         target_language_model = model.get_language_model()
```
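The new `KimiK25ForConditionalGeneration` branch extends an if/elif chain that maps each multimodal model class to the config attribute holding its image-placeholder token id. An equivalent table-driven form is sketched below; the attribute paths come from the diff, but `_IMAGE_TOKEN_ATTR` and `resolve_image_token_index` are hypothetical names, not the PR's code.

```python
# Sketch only: the same model-name dispatch as a lookup table.
_IMAGE_TOKEN_ATTR = {
    "PixtralForConditionalGeneration": ("vision_config", "image_token_id"),
    "KimiK25ForConditionalGeneration": ("media_placeholder_token_id",),
}


def resolve_image_token_index(model_name: str, config) -> int:
    # Walk the attribute path for the model, defaulting to the generic
    # `image_token_index` field used by the `else` branch above.
    path = _IMAGE_TOKEN_ATTR.get(model_name, ("image_token_index",))
    value = config
    for attr in path:
        value = getattr(value, attr)
    return value
```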
```diff
@@ -388,7 +390,11 @@ class SpecDecodeBaseProposer(EagleProposer):
             : num_reqs * self.decode_threshold
         ]
 
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         # update the tensor's address for each step.
         for draft_step in range(self.num_speculative_tokens):
             common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
```
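Each draft step rebinds tensor fields on the metadata object ("update the tensor's address for each step"), so the loop works on a shallow copy: per-step fields can be swapped without mutating metadata shared across steps, while unchanged tensors stay shared. A minimal sketch of the idea; `StepMetadata` is a stand-in for the real metadata class, and the actual `shallow_copy_metadata` may be implemented differently.

```python
import copy
from dataclasses import dataclass

import torch


@dataclass
class StepMetadata:
    slot_mapping: torch.Tensor
    num_input_tokens: int


base = StepMetadata(slot_mapping=torch.arange(8), num_input_tokens=8)
step = copy.copy(base)                      # copy the container, share the tensors
step.slot_mapping = base.slot_mapping[:4]   # rebind on the copy only
assert base.slot_mapping.numel() == 8       # the original is untouched
```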
```diff
@@ -550,7 +556,11 @@ class SpecDecodeBaseProposer(EagleProposer):
         common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
         common_attn_metadata.num_input_tokens = num_input_tokens
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
 
         if self.uses_mrope:
```
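The same version gate appears in both hunks. If it grows more call sites, it could be hoisted into a single helper; the refactoring sketch below uses the names visible in the diff, but `_get_metadata_builder` is a hypothetical addition, not part of this PR.

```python
# Refactoring sketch only: one method on SpecDecodeBaseProposer replacing
# the duplicated gate in both hunks above.
def _get_metadata_builder(self):
    if vllm_version_is("0.17.0"):
        # On vLLM 0.17.0 the draft model carries its own attention groups.
        assert len(self.draft_attn_groups) > 0
        return self.draft_attn_groups[0].get_metadata_builder()
    # Otherwise fall back to the target model's attention groups,
    # preserving the previous behaviour.
    return self.runner.attn_groups[0][0].get_metadata_builder()
```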