[Feat] Support separate attention backend for target and draft model. (#7342)

### What this PR does / why we need it?
This PR enables separate attention backend configuration for the target and
draft models in speculative decoding, decoupling the attention backend
settings that were previously bound together for the two models.

It resolves the compatibility issue where some draft models do not support
the attention backend used by the target model, and lets users select the
best attention backend for each model individually to maximize inference
performance. The change is fully backward compatible.
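
To make the decoupling concrete, here is a minimal, self-contained sketch (illustrative only, not code from this PR): with per-layer KV cache specs, an MLA-style target layer and a GQA draft layer can resolve different K/V cache shapes instead of both inheriting the target backend's layout. The `ToyAttnSpec` class, layer names, and dimensions below are made up for the example.

```python
# Illustrative sketch: per-layer specs drive per-layer KV cache shapes.
from dataclasses import dataclass


@dataclass
class ToyAttnSpec:
    num_kv_heads: int
    head_size: int    # key head dim (e.g. kv_lora_rank for an MLA-style layer)
    head_size_v: int  # value head dim (e.g. qk_rope_head_dim for an MLA-style layer)


def kv_cache_shapes(layer_specs: dict[str, ToyAttnSpec], num_blocks: int, block_size: int):
    """Return a (k_shape, v_shape) pair per layer, derived from that layer's own spec."""
    shapes = {}
    for name, spec in layer_specs.items():
        k = (num_blocks, block_size, spec.num_kv_heads, spec.head_size)
        v = (num_blocks, block_size, spec.num_kv_heads, spec.head_size_v)
        shapes[name] = (k, v)
    return shapes


specs = {
    # MLA-style target layer: single latent head, asymmetric nope/rope caches.
    "target_attn": ToyAttnSpec(num_kv_heads=1, head_size=512, head_size_v=64),
    # Plain GQA draft layer: symmetric K/V caches.
    "draft_attn": ToyAttnSpec(num_kv_heads=8, head_size=64, head_size_v=64),
}
print(kv_cache_shapes(specs, num_blocks=2, block_size=16))
# {'target_attn': ((2, 16, 1, 512), (2, 16, 1, 64)),
#  'draft_attn': ((2, 16, 8, 64), (2, 16, 8, 64))}
```

Previously, every attention layer's cache shape was derived from the target model's backend and `hf_text_config` dimensions; with this change, each layer's shape comes from its own spec (see `_get_layer_kv_cache_specs` / `_get_attention_kv_cache_dims` in the diff below).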
---------
Signed-off-by: SidaoY <1024863041@qq.com>
HongtaoYang
2026-03-21 10:48:01 +08:00
committed by GitHub
parent 88d03a783f
commit 80a4265717
3 changed files with 177 additions and 49 deletions

View File

@@ -0,0 +1,89 @@
import unittest
from types import SimpleNamespace
from unittest.mock import MagicMock

import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor

from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


class TestNPUModelRunnerKVCache(unittest.TestCase):

    def _build_runner(self):
        runner = NPUModelRunner.__new__(NPUModelRunner)
        runner.device = torch.device("cpu")
        runner.use_sparse = False
        runner.use_sparse_c8_indexer = False
        runner.use_hybrid_blocks = False
        runner.hybrid_with_attn_and_mamba = False
        runner.runner_only_attn_layers = set()
        runner.is_kv_consumer = False
        runner.vllm_config = MagicMock()
        runner.vllm_config.kv_transfer_config = None
        runner.model_config = MagicMock()
        runner.model_config.use_mla = True
        backend = MagicMock()
        backend.get_kv_cache_shape.side_effect = lambda num_blocks, block_size, num_kv_heads, head_size: (
            2,
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
        )
        runner.attn_backend = backend
        return runner

    def test_allocate_kv_cache_uses_layer_spec_for_draft_gqa(self):
        runner = self._build_runner()
        kv_cache_spec = FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=64,
            head_size_v=64,
            dtype=torch.float16,
        )
        kv_cache_config = KVCacheConfig(
            num_blocks=2,
            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
        )
        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
        k_cache_raw, v_cache_raw = kv_cache_raw_tensors["draft_attn"]
        self.assertEqual(k_cache_raw.numel(), kv_cache_spec.page_size_bytes)
        self.assertEqual(v_cache_raw.numel(), kv_cache_spec.page_size_bytes)

    def test_reshape_kv_cache_uses_layer_spec_for_draft_gqa(self):
        runner = self._build_runner()
        kv_cache_spec = FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=64,
            head_size_v=64,
            dtype=torch.float16,
        )
        kv_cache_config = KVCacheConfig(
            num_blocks=2,
            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
        )
        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
        runner._kv_cache_spec_attn_group_iterator = lambda: [
            SimpleNamespace(
                kv_cache_spec=kv_cache_spec,
                backend=runner.attn_backend,
                layer_names=["draft_attn"],
            )
        ]
        kv_caches = runner._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors)
        k_cache, v_cache = kv_caches["draft_attn"]
        self.assertEqual(k_cache.shape, (2, 16, 8, 64))
        self.assertEqual(v_cache.shape, (2, 16, 8, 64))


if __name__ == "__main__":
    unittest.main()

View File

@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
# Currently we will fix block size to a small one since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -217,6 +217,8 @@ class SpecDecodeBaseProposer(EagleProposer):
self.model.config.image_token_index = model.config.image_token_id
elif self.get_model_name(model) == "PixtralForConditionalGeneration":
self.model.config.image_token_index = model.config.vision_config.image_token_id
elif self.get_model_name(model) == "KimiK25ForConditionalGeneration":
self.model.config.image_token_index = model.config.media_placeholder_token_id
else:
self.model.config.image_token_index = model.config.image_token_index
target_language_model = model.get_language_model()
@@ -388,7 +390,11 @@ class SpecDecodeBaseProposer(EagleProposer):
: num_reqs * self.decode_threshold
]
builder = self.runner.attn_groups[0][0].get_metadata_builder()
if vllm_version_is("0.17.0"):
assert len(self.draft_attn_groups) > 0
builder = self.draft_attn_groups[0].get_metadata_builder()
else:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
# update the tensor's address for each step.
for draft_step in range(self.num_speculative_tokens):
common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
@@ -550,7 +556,11 @@ class SpecDecodeBaseProposer(EagleProposer):
common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
common_attn_metadata.num_input_tokens = num_input_tokens
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
builder = self.runner.attn_groups[0][0].get_metadata_builder()
if vllm_version_is("0.17.0"):
assert len(self.draft_attn_groups) > 0
builder = self.draft_attn_groups[0].get_metadata_builder()
else:
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
if self.uses_mrope:

View File

@@ -2657,6 +2657,34 @@ class NPUModelRunner(GPUModelRunner):
bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches, num_attn_module)
return kv_caches
def _get_layer_kv_cache_specs(self, kv_cache_config: KVCacheConfig) -> dict[str, KVCacheSpec]:
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
group_spec = group_kv_cache_spec.kv_cache_spec
for layer_name in group_kv_cache_spec.layer_names:
if isinstance(group_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec[layer_name] = group_spec.kv_cache_specs[layer_name]
else:
layer_kv_cache_spec[layer_name] = group_spec
return layer_kv_cache_spec
def _get_attention_kv_cache_dims(self, layer_name: str, kv_cache_spec: AttentionSpec) -> tuple[int, int]:
if isinstance(kv_cache_spec, MLAAttentionSpec):
attn_layers = get_layers_from_vllm_config(
self.vllm_config,
AttentionLayerBase,
[layer_name],
)
attn_layer = attn_layers[layer_name]
if not isinstance(attn_layer, MLAAttention):
raise TypeError(
f"Expected MLAAttention layer for {layer_name}, got {type(attn_layer).__name__}."
)
return attn_layer.kv_lora_rank, attn_layer.qk_rope_head_dim
head_size_v = kv_cache_spec.head_size_v if hasattr(kv_cache_spec, "head_size_v") else kv_cache_spec.head_size
return kv_cache_spec.head_size, head_size_v
def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initializes the KV cache buffer with the correct size. The buffer needs
@@ -2677,10 +2705,7 @@ class NPUModelRunner(GPUModelRunner):
kv_cache_raw_tensors: dict[str, torch.Tensor | None] = {}
# prefill disaggregation needs the address of the cache tensor to be aligned to 2 MB
alignment = 2 * 1024 * 1024
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
for layer_name in group_kv_cache_spec.layer_names:
layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
# If some tensors are shared by linear layers and attention layers,
# the same tensor format must be maintained even if some layers
# have only linear or attention layers, for example, the mtp layer.
@@ -2715,11 +2740,10 @@ class NPUModelRunner(GPUModelRunner):
# as it only supports `num_blocks` as the 0th dim of kv_cache.
# For deepseek mla, we need to split the cache tensor according to the nope head dim
# and rope head dim.
if not self.model_config.use_mla:
# for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2.0
v_tensor_split_factor = 2.0
elif self.use_sparse:
current_kv_cache_spec = layer_kv_cache_spec[layer_name]
assert isinstance(current_kv_cache_spec, AttentionSpec)
if self.use_sparse:
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
kv_cache_spec = layer_kv_cache_spec[layer_name]
sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
@@ -2728,10 +2752,11 @@ class NPUModelRunner(GPUModelRunner):
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
else:
# for other deepseek models, use MLAAttentionSpec
k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
assert k_dim > 0 and v_dim > 0
kv_head_dim_list = [
self.model_config.hf_text_config.kv_lora_rank,
self.model_config.hf_text_config.qk_rope_head_dim,
k_dim,
v_dim,
]
if self.is_kv_consumer and self.vllm_config.quant_config is not None:
k_tensor_split_factor, v_tensor_split_factor = (
@@ -2819,20 +2844,18 @@ class NPUModelRunner(GPUModelRunner):
corresponding memory buffer for KV cache.
"""
kv_caches: dict[str, torch.Tensor] = {}
layer_kv_cache_spec = {}
for group in kv_cache_config.kv_cache_groups:
for layer_name in group.layer_names:
layer_kv_cache_spec[layer_name] = group.kv_cache_spec
layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
current_kv_cache_spec = layer_kv_cache_spec[layer_name]
# TODO: remove this after the OOM issue is located and fixed; otherwise, some models may
# encounter OOM issues
if isinstance(kv_cache_spec, AttentionSpec):
if isinstance(current_kv_cache_spec, AttentionSpec):
if self.use_sparse:
if self.use_sparse_c8_indexer:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
@@ -2862,8 +2885,8 @@ class NPUModelRunner(GPUModelRunner):
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel()
assert raw_k_tensor is not None
assert raw_v_tensor is not None
assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
assert sum_page_size_bytes % current_kv_cache_spec.page_size_bytes == 0
num_blocks = sum_page_size_bytes // current_kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
@@ -2877,48 +2900,54 @@ class NPUModelRunner(GPUModelRunner):
if hasattr(attn_backend, "get_supported_kernel_block_sizes") and self.use_hybrid_blocks:
block_size = attn_backend.get_supported_kernel_block_sizes()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
block_size_chunk = current_kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks * block_size_chunk,
block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
current_kv_cache_spec.num_kv_heads,
current_kv_cache_spec.head_size,
)
if self.hybrid_with_attn_and_mamba:
attn_tensor_page_size = int(np.prod(kv_cache_shape[1:])) * get_dtype_size(
kv_cache_spec.dtype
current_kv_cache_spec.dtype
)
conv_block_padding_size = raw_k_tensor.numel() - attn_tensor_page_size * 2
raw_kv_tensor = raw_k_tensor[conv_block_padding_size:]
raw_k_tensor = raw_kv_tensor[:attn_tensor_page_size]
raw_v_tensor = raw_kv_tensor[attn_tensor_page_size:]
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks,
current_kv_cache_spec.block_size,
current_kv_cache_spec.num_kv_heads,
current_kv_cache_spec.head_size,
)
if not self.model_config.use_mla:
if not isinstance(current_kv_cache_spec, MLAAttentionSpec):
k_shape = kv_cache_shape[1:]
v_shape = k_shape
if hasattr(current_kv_cache_spec, "head_size_v"):
v_shape = (*kv_cache_shape[1:-1], current_kv_cache_spec.head_size_v)
else:
v_shape = k_shape
else:
# k_cache: nope_cache v_cache: rope_cache
mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
k_shape = [
k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
k_shape = (
mla_num_blocks,
mla_block_size,
num_kv_heads,
self.model_config.hf_text_config.kv_lora_rank,
]
v_shape = [
k_dim,
)
v_shape = (
mla_num_blocks,
mla_block_size,
num_kv_heads,
self.model_config.hf_text_config.qk_rope_head_dim,
]
k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype
v_dim,
)
k_cache_dtype = v_cache_dtype = current_kv_cache_spec.dtype
if self.is_kv_consumer and self.vllm_config.quant_config is not None:
k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype(
layer_name, kv_cache_spec.dtype, self.model_config
layer_name, current_kv_cache_spec.dtype, self.model_config
)
k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape)
v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape)
@@ -2926,8 +2955,8 @@ class NPUModelRunner(GPUModelRunner):
if self.use_sparse:
dsa_k_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
current_kv_cache_spec.block_size,
current_kv_cache_spec.num_kv_heads,
self.model_config.hf_text_config.index_head_dim,
)
if self.use_sparse_c8_indexer:
@@ -2936,8 +2965,8 @@ class NPUModelRunner(GPUModelRunner):
# dsa_k_scale
dsa_k_scale_cache_shape = (
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
current_kv_cache_spec.block_size,
current_kv_cache_spec.num_kv_heads,
1,
)
assert raw_dsa_k_scale_tensor is not None
@@ -2949,15 +2978,15 @@ class NPUModelRunner(GPUModelRunner):
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
else:
# dsa_k
dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
dsa_k_cache = raw_dsa_k_tensor.view(current_kv_cache_spec.dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
elif isinstance(current_kv_cache_spec, MambaSpec):
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor is not None
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
assert raw_tensor.numel() % current_kv_cache_spec.page_size_bytes == 0
num_blocks = raw_tensor.numel() // current_kv_cache_spec.page_size_bytes
assert num_blocks >= kv_cache_config.num_blocks
# `num_blocks` is the number of blocks the model runner can use.
@@ -2977,7 +3006,7 @@ class NPUModelRunner(GPUModelRunner):
# tensor1: [(kv_padding), conv , ...]
# tensor2: [k , ssm , ...]
# tensor3: [v , (mamba_padding), ...]
for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
for shape, dtype in zip(current_kv_cache_spec.shapes, current_kv_cache_spec.dtypes):
# normally, there are both a conv state and an ssm state in this loop;
# some special models have only a conv state.
target_shape = (num_blocks, *shape)