From 80a42657173395741753753b392895ad71595b9f Mon Sep 17 00:00:00 2001
From: HongtaoYang <75939043+SidaoY@users.noreply.github.com>
Date: Sat, 21 Mar 2026 10:48:01 +0800
Subject: [PATCH] [Feat] Support separate attention backend for target and draft model. (#7342)

### What this PR does / why we need it?

This PR enables separate attention backend configuration for the target and
draft models in speculative decoding, decoupling the previously bound
attention backend settings between the two models. It solves the
compatibility issue where some draft models do not support the attention
backend used by the target model, and it allows users to select the optimal
attention backend for each model individually to maximize inference
performance. The change is fully backward compatible.
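For example, on vLLM 0.17.0 the draft proposer now resolves its attention
metadata builder from its own attention groups instead of always reusing the
target model's (taken from the `eagle_proposer.py` change in this patch):

```python
# Before: the draft model always reused the target model's attention group.
builder = self.runner.attn_groups[0][0].get_metadata_builder()

# After: the draft model resolves its own attention group when available.
if vllm_version_is("0.17.0"):
    assert len(self.draft_attn_groups) > 0
    builder = self.draft_attn_groups[0].get_metadata_builder()
else:
    builder = self.runner.attn_groups[0][0].get_metadata_builder()
```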
---------

Signed-off-by: SidaoY <1024863041@qq.com>
---
 tests/ut/worker/test_model_runner_v1.py   |  89 ++++++++++++++++
 vllm_ascend/spec_decode/eagle_proposer.py |  16 ++-
 vllm_ascend/worker/model_runner_v1.py     | 121 ++++++++++++++--------
 3 files changed, 177 insertions(+), 49 deletions(-)
 create mode 100644 tests/ut/worker/test_model_runner_v1.py

diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py
new file mode 100644
index 00000000..0cb09faf
--- /dev/null
+++ b/tests/ut/worker/test_model_runner_v1.py
@@ -0,0 +1,89 @@
+import unittest
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import torch
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor
+
+from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
+
+
+class TestNPUModelRunnerKVCache(unittest.TestCase):
+
+    def _build_runner(self):
+        runner = NPUModelRunner.__new__(NPUModelRunner)
+        runner.device = torch.device("cpu")
+        runner.use_sparse = False
+        runner.use_sparse_c8_indexer = False
+        runner.use_hybrid_blocks = False
+        runner.hybrid_with_attn_and_mamba = False
+        runner.runner_only_attn_layers = set()
+        runner.is_kv_consumer = False
+        runner.vllm_config = MagicMock()
+        runner.vllm_config.kv_transfer_config = None
+        runner.model_config = MagicMock()
+        runner.model_config.use_mla = True
+        backend = MagicMock()
+        backend.get_kv_cache_shape.side_effect = lambda num_blocks, block_size, num_kv_heads, head_size: (
+            2,
+            num_blocks,
+            block_size,
+            num_kv_heads,
+            head_size,
+        )
+        runner.attn_backend = backend
+        return runner
+
+    def test_allocate_kv_cache_uses_layer_spec_for_draft_gqa(self):
+        runner = self._build_runner()
+        kv_cache_spec = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=8,
+            head_size=64,
+            head_size_v=64,
+            dtype=torch.float16,
+        )
+        kv_cache_config = KVCacheConfig(
+            num_blocks=2,
+            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
+            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
+        )
+
+        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
+        k_cache_raw, v_cache_raw = kv_cache_raw_tensors["draft_attn"]
+
+        self.assertEqual(k_cache_raw.numel(), kv_cache_spec.page_size_bytes)
+        self.assertEqual(v_cache_raw.numel(), kv_cache_spec.page_size_bytes)
+
+    def test_reshape_kv_cache_uses_layer_spec_for_draft_gqa(self):
+        runner = self._build_runner()
+        kv_cache_spec = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=8,
+            head_size=64,
+            head_size_v=64,
+            dtype=torch.float16,
+        )
+        kv_cache_config = KVCacheConfig(
+            num_blocks=2,
+            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
+            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
+        )
+        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
+        runner._kv_cache_spec_attn_group_iterator = lambda: [
+            SimpleNamespace(
+                kv_cache_spec=kv_cache_spec,
+                backend=runner.attn_backend,
+                layer_names=["draft_attn"],
+            )
+        ]
+
+        kv_caches = runner._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors)
+        k_cache, v_cache = kv_caches["draft_attn"]
+
+        self.assertEqual(k_cache.shape, (2, 16, 8, 64))
+        self.assertEqual(v_cache.shape, (2, 16, 8, 64))
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file

diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index b9b6c644..869cd287 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
 from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
 
 # Currently we fix the block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -217,6 +217,8 @@ class SpecDecodeBaseProposer(EagleProposer):
             self.model.config.image_token_index = model.config.image_token_id
         elif self.get_model_name(model) == "PixtralForConditionalGeneration":
             self.model.config.image_token_index = model.config.vision_config.image_token_id
+        elif self.get_model_name(model) == "KimiK25ForConditionalGeneration":
+            self.model.config.image_token_index = model.config.media_placeholder_token_id
         else:
             self.model.config.image_token_index = model.config.image_token_index
         target_language_model = model.get_language_model()
@@ -388,7 +390,11 @@ class SpecDecodeBaseProposer(EagleProposer):
             : num_reqs * self.decode_threshold
         ]
 
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         # update the tensor's address for each step.
         for draft_step in range(self.num_speculative_tokens):
             common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
@@ -550,7 +556,11 @@ class SpecDecodeBaseProposer(EagleProposer):
         common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
         common_attn_metadata.num_input_tokens = num_input_tokens
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-        builder = self.runner.attn_groups[0][0].get_metadata_builder()
+        if vllm_version_is("0.17.0"):
+            assert len(self.draft_attn_groups) > 0
+            builder = self.draft_attn_groups[0].get_metadata_builder()
+        else:
+            builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
 
         if self.uses_mrope:

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a49182ce..d1f86bf0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2657,6 +2657,34 @@ class NPUModelRunner(GPUModelRunner):
         bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches, num_attn_module)
         return kv_caches
 
+    def _get_layer_kv_cache_specs(self, kv_cache_config: KVCacheConfig) -> dict[str, KVCacheSpec]:
+        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
+        for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
+            group_spec = group_kv_cache_spec.kv_cache_spec
+            for layer_name in group_kv_cache_spec.layer_names:
+                if isinstance(group_spec, UniformTypeKVCacheSpecs):
+                    layer_kv_cache_spec[layer_name] = group_spec.kv_cache_specs[layer_name]
+                else:
+                    layer_kv_cache_spec[layer_name] = group_spec
+        return layer_kv_cache_spec
+
+    def _get_attention_kv_cache_dims(self, layer_name: str, kv_cache_spec: AttentionSpec) -> tuple[int, int]:
+        if isinstance(kv_cache_spec, MLAAttentionSpec):
+            attn_layers = get_layers_from_vllm_config(
+                self.vllm_config,
+                AttentionLayerBase,
+                [layer_name],
+            )
+            attn_layer = attn_layers[layer_name]
+            if not isinstance(attn_layer, MLAAttention):
+                raise TypeError(
+                    f"Expected MLAAttention layer for {layer_name}, got {type(attn_layer).__name__}."
+                )
+            return attn_layer.kv_lora_rank, attn_layer.qk_rope_head_dim
+
+        head_size_v = kv_cache_spec.head_size_v if hasattr(kv_cache_spec, "head_size_v") else kv_cache_spec.head_size
+        return kv_cache_spec.head_size, head_size_v
+
     def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
         Initializes the KV cache buffer with the correct size. The buffer needs
@@ -2677,10 +2705,7 @@
         kv_cache_raw_tensors: dict[str, torch.Tensor | None] = {}
         # prefill disaggregation needs the cache tensor address to be 2M-aligned
         alignment = 2 * 1024 * 1024
-        layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
-            for layer_name in group_kv_cache_spec.layer_names:
-                layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
+        layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
         # If some tensors are shared by linear layers and attention layers,
         # the same tensor format must be maintained even if some layers
         # have only linear or attention layers, for example, the mtp layer.
@@ -2715,11 +2740,10 @@
                 # as it only supports kv_cache whose 0th dim is `num_blocks`.
                 # For deepseek mla, we need to split the cache tensor according to the nope head dim
                 # and rope head dim.
-                if not self.model_config.use_mla:
-                    # for non-mla model, use FullAttentionSpec
-                    k_tensor_split_factor = 2.0
-                    v_tensor_split_factor = 2.0
-                elif self.use_sparse:
+                current_kv_cache_spec = layer_kv_cache_spec[layer_name]
+                assert isinstance(current_kv_cache_spec, AttentionSpec)
+
+                if self.use_sparse:
                     # for deepseek v3.2, we split the kv cache according to the corresponding ratio
                     kv_cache_spec = layer_kv_cache_spec[layer_name]
                     sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
@@ -2728,10 +2752,11 @@
                     dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
                     dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
                 else:
-                    # for other deepseek models, use MLAAttentionSpec
+                    k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
+                    assert k_dim > 0 and v_dim > 0
                     kv_head_dim_list = [
-                        self.model_config.hf_text_config.kv_lora_rank,
-                        self.model_config.hf_text_config.qk_rope_head_dim,
+                        k_dim,
+                        v_dim,
                     ]
                     if self.is_kv_consumer and self.vllm_config.quant_config is not None:
                         k_tensor_split_factor, v_tensor_split_factor = (
@@ -2819,20 +2844,18 @@
         corresponding memory buffer for KV cache.
         """
         kv_caches: dict[str, torch.Tensor] = {}
-        layer_kv_cache_spec = {}
-        for group in kv_cache_config.kv_cache_groups:
-            for layer_name in group.layer_names:
-                layer_kv_cache_spec[layer_name] = group.kv_cache_spec
+        layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
 
         for group in self._kv_cache_spec_attn_group_iterator():
-            kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
+                current_kv_cache_spec = layer_kv_cache_spec[layer_name]
+
                 # TODO: remove this after the OOM issue is located and fixed;
                 # otherwise, some models may encounter OOM issues
-                if isinstance(kv_cache_spec, AttentionSpec):
+                if isinstance(current_kv_cache_spec, AttentionSpec):
                     if self.use_sparse:
                         if self.use_sparse_c8_indexer:
                             raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[  # type: ignore
@@ -2862,8 +2885,8 @@
                     sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel()
                     assert raw_k_tensor is not None
                     assert raw_v_tensor is not None
-                    assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
-                    num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
+                    assert sum_page_size_bytes % current_kv_cache_spec.page_size_bytes == 0
+                    num_blocks = sum_page_size_bytes // current_kv_cache_spec.page_size_bytes
                     # `num_blocks` is the number of blocks the model runner can use.
                    # `kv_cache_config.num_blocks` is the number of blocks that
                    # KVCacheManager may allocate.
@@ -2877,48 +2900,54 @@
                     if hasattr(attn_backend, "get_supported_kernel_block_sizes") and self.use_hybrid_blocks:
                         block_size = attn_backend.get_supported_kernel_block_sizes()[0]
-                        block_size_chunk = kv_cache_spec.block_size // block_size
+                        block_size_chunk = current_kv_cache_spec.block_size // block_size
                         kv_cache_shape = attn_backend.get_kv_cache_shape(
                             num_blocks * block_size_chunk,
                             block_size,
-                            kv_cache_spec.num_kv_heads,
-                            kv_cache_spec.head_size,
+                            current_kv_cache_spec.num_kv_heads,
+                            current_kv_cache_spec.head_size,
                         )
                         if self.hybrid_with_attn_and_mamba:
                             attn_tensor_page_size = int(np.prod(kv_cache_shape[1:])) * get_dtype_size(
-                                kv_cache_spec.dtype
+                                current_kv_cache_spec.dtype
                             )
                             conv_block_padding_size = raw_k_tensor.numel() - attn_tensor_page_size * 2
                             raw_kv_tensor = raw_k_tensor[conv_block_padding_size:]
                             raw_k_tensor = raw_kv_tensor[:attn_tensor_page_size]
                             raw_v_tensor = raw_kv_tensor[attn_tensor_page_size:]
                     else:
-                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-                            num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks,
+                            current_kv_cache_spec.block_size,
+                            current_kv_cache_spec.num_kv_heads,
+                            current_kv_cache_spec.head_size,
                         )
-
-                    if not self.model_config.use_mla:
+                    if not isinstance(current_kv_cache_spec, MLAAttentionSpec):
                         k_shape = kv_cache_shape[1:]
-                        v_shape = k_shape
+                        if hasattr(current_kv_cache_spec, "head_size_v"):
+                            v_shape = (*kv_cache_shape[1:-1], current_kv_cache_spec.head_size_v)
+                        else:
+                            v_shape = k_shape
                     else:
                         # k_cache: nope_cache  v_cache: rope_cache
                         mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
-                        k_shape = [
+                        k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
+                        k_shape = (
                             mla_num_blocks,
                             mla_block_size,
                             num_kv_heads,
-                            self.model_config.hf_text_config.kv_lora_rank,
-                        ]
-                        v_shape = [
+                            k_dim,
+                        )
+                        v_shape = (
                             mla_num_blocks,
                             mla_block_size,
                             num_kv_heads,
-                            self.model_config.hf_text_config.qk_rope_head_dim,
-                        ]
-                    k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype
+                            v_dim,
+                        )
+                    k_cache_dtype = v_cache_dtype = current_kv_cache_spec.dtype
                     if self.is_kv_consumer and self.vllm_config.quant_config is not None:
                         k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype(
-                            layer_name, kv_cache_spec.dtype, self.model_config
+                            layer_name, current_kv_cache_spec.dtype, self.model_config
                         )
                     k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape)
                     v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape)
@@ -2926,8 +2955,8 @@
                     if self.use_sparse:
                         dsa_k_cache_shape = (
                             num_blocks,
-                            kv_cache_spec.block_size,
-                            kv_cache_spec.num_kv_heads,
+                            current_kv_cache_spec.block_size,
+                            current_kv_cache_spec.num_kv_heads,
                             self.model_config.hf_text_config.index_head_dim,
                         )
                         if self.use_sparse_c8_indexer:
@@ -2936,8 +2965,8 @@
                             # dsa_k_scale
                             dsa_k_scale_cache_shape = (
                                 num_blocks,
-                                kv_cache_spec.block_size,
-                                kv_cache_spec.num_kv_heads,
+                                current_kv_cache_spec.block_size,
+                                current_kv_cache_spec.num_kv_heads,
                                 1,
                             )
                             assert raw_dsa_k_scale_tensor is not None
@@ -2949,15 +2978,15 @@
                             kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
                         else:
                             # dsa_k
-                            dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
+                            dsa_k_cache = raw_dsa_k_tensor.view(current_kv_cache_spec.dtype).view(dsa_k_cache_shape)
                             kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
                     else:
                         kv_caches[layer_name] = (k_cache, v_cache)
-                elif isinstance(kv_cache_spec, MambaSpec):
+                elif isinstance(current_kv_cache_spec, MambaSpec):
                     raw_tensor = kv_cache_raw_tensors[layer_name]
                     assert raw_tensor is not None
-                    assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
-                    num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
+                    assert raw_tensor.numel() % current_kv_cache_spec.page_size_bytes == 0
+                    num_blocks = raw_tensor.numel() // current_kv_cache_spec.page_size_bytes
                     assert num_blocks >= kv_cache_config.num_blocks
 
                     # `num_blocks` is the number of blocks the model runner can use.
@@ -2977,7 +3006,7 @@
                     # tensor1: [(kv_padding), conv           , ...]
                     # tensor2: [k           , ssm            , ...]
                     # tensor3: [v           , (mamba_padding), ...]
-                    for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
+                    for shape, dtype in zip(current_kv_cache_spec.shapes, current_kv_cache_spec.dtypes):
                         # normally there are a conv state and an ssm state in this loop, and only
                         # a conv state in some special models.
                         target_shape = (num_blocks, *shape)
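The core of the model-runner change is that KV cache shapes are now resolved
per layer instead of per attention group, which is what lets a GQA draft
layer coexist with an MLA (or otherwise differently shaped) target model.
Below is a minimal, self-contained sketch of that resolution logic; the
dataclasses are toy stand-ins for vLLM's `KVCacheGroupSpec`,
`UniformTypeKVCacheSpecs`, and attention specs, not the real interfaces:

```python
from dataclasses import dataclass


@dataclass
class LayerSpec:            # stand-in for an AttentionSpec
    num_kv_heads: int
    head_size: int


@dataclass
class UniformTypeGroupSpec:  # stand-in for UniformTypeKVCacheSpecs
    kv_cache_specs: dict[str, LayerSpec]


@dataclass
class Group:                 # stand-in for KVCacheGroupSpec
    layer_names: list[str]
    kv_cache_spec: object    # LayerSpec or UniformTypeGroupSpec


def get_layer_specs(groups: list[Group]) -> dict[str, LayerSpec]:
    """Flatten group-level specs into one spec per layer, so a draft layer
    can carry a different KV cache shape than the target layers it is
    grouped with."""
    layer_specs: dict[str, LayerSpec] = {}
    for group in groups:
        spec = group.kv_cache_spec
        for name in group.layer_names:
            if isinstance(spec, UniformTypeGroupSpec):
                layer_specs[name] = spec.kv_cache_specs[name]  # per-layer spec
            else:
                layer_specs[name] = spec                       # shared spec
    return layer_specs


if __name__ == "__main__":
    target = LayerSpec(num_kv_heads=1, head_size=576)  # e.g. an MLA target layer
    draft = LayerSpec(num_kv_heads=8, head_size=64)    # e.g. a GQA draft layer
    groups = [
        Group(
            layer_names=["target_attn", "draft_attn"],
            kv_cache_spec=UniformTypeGroupSpec(
                kv_cache_specs={"target_attn": target, "draft_attn": draft}
            ),
        )
    ]
    specs = get_layer_specs(groups)
    assert specs["draft_attn"].head_size == 64  # draft keeps its own shape
```

Here `UniformTypeGroupSpec` plays the role that `UniformTypeKVCacheSpecs`
plays in the patch: the group stores one spec per layer, and the per-layer
lookup is what allows the draft layer's `num_kv_heads`/`head_size` to differ
from the target's within the same KV cache group.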