[Feat] Support separate attention backend for target and draft model. (#7342)
### What this PR does / why we need it?

This PR enables separate attention backend configuration for the target and draft models in speculative decoding, decoupling the previously shared attention backend settings between the two models. It solves the compatibility issue where some draft models do not support the attention backend used by the target model, and it allows users to select the optimal attention backend for each model individually to maximize inference performance. The change is fully backward compatible.

Signed-off-by: SidaoY <1024863041@qq.com>
This commit is contained in:
89
tests/ut/worker/test_model_runner_v1.py
Normal file
89
tests/ut/worker/test_model_runner_v1.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor
|
||||
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
|
||||
class TestNPUModelRunnerKVCache(unittest.TestCase):
    """KV-cache allocation/reshaping tests for NPUModelRunner.

    These tests cover the speculative-decoding scenario where the draft
    layer uses a GQA layout (``FullAttentionSpec``) while the target model
    is configured for MLA (``model_config.use_mla = True``): the per-layer
    spec must take precedence over the target-model-wide MLA setting.
    """

    def _build_runner(self):
        """Build a minimally-initialized NPUModelRunner.

        ``__new__`` bypasses ``__init__`` (which requires a full vLLM
        config); only the attributes read by the KV-cache helpers under
        test are set.
        """
        runner = NPUModelRunner.__new__(NPUModelRunner)
        runner.device = torch.device("cpu")
        runner.use_sparse = False
        runner.use_sparse_c8_indexer = False
        runner.use_hybrid_blocks = False
        runner.hybrid_with_attn_and_mamba = False
        runner.runner_only_attn_layers = set()
        runner.is_kv_consumer = False
        runner.vllm_config = MagicMock()
        runner.vllm_config.kv_transfer_config = None
        runner.model_config = MagicMock()
        # Target model is MLA; the draft layer spec built below is GQA.
        runner.model_config.use_mla = True
        backend = MagicMock()
        # Mimic the (2, num_blocks, block_size, num_kv_heads, head_size)
        # K/V-stacked shape that real attention backends report.
        backend.get_kv_cache_shape.side_effect = lambda num_blocks, block_size, num_kv_heads, head_size: (
            2,
            num_blocks,
            block_size,
            num_kv_heads,
            head_size,
        )
        runner.attn_backend = backend
        return runner

    def _build_draft_gqa_kv_cache(self):
        """Return a (spec, config) pair for one GQA draft layer.

        The single tensor holds two pages (K and V each get one), shared
        fixture for both tests below.
        """
        kv_cache_spec = FullAttentionSpec(
            block_size=16,
            num_kv_heads=8,
            head_size=64,
            head_size_v=64,
            dtype=torch.float16,
        )
        kv_cache_config = KVCacheConfig(
            num_blocks=2,
            kv_cache_tensors=[KVCacheTensor(size=kv_cache_spec.page_size_bytes * 2, shared_by=["draft_attn"])],
            kv_cache_groups=[KVCacheGroupSpec(layer_names=["draft_attn"], kv_cache_spec=kv_cache_spec)],
        )
        return kv_cache_spec, kv_cache_config

    def test_allocate_kv_cache_uses_layer_spec_for_draft_gqa(self):
        runner = self._build_runner()
        kv_cache_spec, kv_cache_config = self._build_draft_gqa_kv_cache()

        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
        k_cache_raw, v_cache_raw = kv_cache_raw_tensors["draft_attn"]

        # The two-page tensor is split evenly: one page for K, one for V.
        self.assertEqual(k_cache_raw.numel(), kv_cache_spec.page_size_bytes)
        self.assertEqual(v_cache_raw.numel(), kv_cache_spec.page_size_bytes)

    def test_reshape_kv_cache_uses_layer_spec_for_draft_gqa(self):
        runner = self._build_runner()
        kv_cache_spec, kv_cache_config = self._build_draft_gqa_kv_cache()
        kv_cache_raw_tensors = runner._allocate_kv_cache_tensors(kv_cache_config)
        # Stub the attn-group iterator so reshape sees exactly one GQA group.
        runner._kv_cache_spec_attn_group_iterator = lambda: [
            SimpleNamespace(
                kv_cache_spec=kv_cache_spec,
                backend=runner.attn_backend,
                layer_names=["draft_attn"],
            )
        ]

        kv_caches = runner._reshape_kv_cache_tensors(kv_cache_config, kv_cache_raw_tensors)
        k_cache, v_cache = kv_caches["draft_attn"]

        # GQA layout (num_blocks, block_size, num_kv_heads, head_size) —
        # not an MLA nope/rope split, despite use_mla=True on the target.
        self.assertEqual(k_cache.shape, (2, 16, 8, 64))
        self.assertEqual(v_cache.shape, (2, 16, 8, 64))
# Allow running this test module directly with `python test_model_runner_v1.py`.
if __name__ == "__main__":
    unittest.main()
@@ -46,7 +46,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper, update_full_graph_params
|
||||
from vllm_ascend.ops.triton.spec_decode.utils import prepare_inputs_padded_kernel
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled
|
||||
from vllm_ascend.utils import enable_sp, lmhead_tp_enable, shared_expert_dp_enabled, vllm_version_is
|
||||
|
||||
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
||||
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
||||
@@ -217,6 +217,8 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
self.model.config.image_token_index = model.config.image_token_id
|
||||
elif self.get_model_name(model) == "PixtralForConditionalGeneration":
|
||||
self.model.config.image_token_index = model.config.vision_config.image_token_id
|
||||
elif self.get_model_name(model) == "KimiK25ForConditionalGeneration":
|
||||
self.model.config.image_token_index = model.config.media_placeholder_token_id
|
||||
else:
|
||||
self.model.config.image_token_index = model.config.image_token_index
|
||||
target_language_model = model.get_language_model()
|
||||
@@ -388,7 +390,11 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
: num_reqs * self.decode_threshold
|
||||
]
|
||||
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
if vllm_version_is("0.17.0"):
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
else:
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
# update the tensor's address for each step.
|
||||
for draft_step in range(self.num_speculative_tokens):
|
||||
common_attn_metadata = self.shallow_copy_metadata(common_attn_metadata)
|
||||
@@ -550,7 +556,11 @@ class SpecDecodeBaseProposer(EagleProposer):
|
||||
common_attn_metadata.slot_mapping = self.slot_mapping_group[0]
|
||||
common_attn_metadata.num_input_tokens = num_input_tokens
|
||||
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
if vllm_version_is("0.17.0"):
|
||||
assert len(self.draft_attn_groups) > 0
|
||||
builder = self.draft_attn_groups[0].get_metadata_builder()
|
||||
else:
|
||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||
attn_metadata = builder.build(0, common_attn_metadata, self.runner.get_model())
|
||||
|
||||
if self.uses_mrope:
|
||||
|
||||
@@ -2657,6 +2657,34 @@ class NPUModelRunner(GPUModelRunner):
|
||||
bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches, num_attn_module)
|
||||
return kv_caches
|
||||
|
||||
def _get_layer_kv_cache_specs(self, kv_cache_config: KVCacheConfig) -> dict[str, KVCacheSpec]:
    """Map every layer name in *kv_cache_config* to its own KV cache spec.

    A ``UniformTypeKVCacheSpecs`` group carries one spec per layer, so each
    layer gets its individual entry; any other group shares a single spec
    across all of its layers.
    """
    specs: dict[str, KVCacheSpec] = {}
    for group in kv_cache_config.kv_cache_groups:
        shared_spec = group.kv_cache_spec
        per_layer = isinstance(shared_spec, UniformTypeKVCacheSpecs)
        for name in group.layer_names:
            specs[name] = shared_spec.kv_cache_specs[name] if per_layer else shared_spec
    return specs
|
||||
def _get_attention_kv_cache_dims(self, layer_name: str, kv_cache_spec: AttentionSpec) -> tuple[int, int]:
    """Return the (k_dim, v_dim) head dimensions for *layer_name*.

    MLA specs take their dims from the layer object itself
    (``kv_lora_rank`` / ``qk_rope_head_dim``); all other specs use the
    spec's head sizes, with ``head_size`` as the fallback when the spec
    has no separate ``head_size_v``.
    """
    if not isinstance(kv_cache_spec, MLAAttentionSpec):
        v_dim = getattr(kv_cache_spec, "head_size_v", kv_cache_spec.head_size)
        return kv_cache_spec.head_size, v_dim

    layer = get_layers_from_vllm_config(
        self.vllm_config,
        AttentionLayerBase,
        [layer_name],
    )[layer_name]
    if not isinstance(layer, MLAAttention):
        raise TypeError(
            f"Expected MLAAttention layer for {layer_name}, got {type(layer).__name__}."
        )
    return layer.kv_lora_rank, layer.qk_rope_head_dim
||||
def _allocate_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
@@ -2677,10 +2705,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor | torch.Tensor | None | None] = {}
|
||||
# prefill disaggregation need the addr of cache tensor be aligned with 2M
|
||||
alignment = 2 * 1024 * 1024
|
||||
layer_kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
for group_kv_cache_spec in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group_kv_cache_spec.layer_names:
|
||||
layer_kv_cache_spec[layer_name] = group_kv_cache_spec.kv_cache_spec
|
||||
layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
|
||||
# If some tensors are shared by linear layers and attention layers,
|
||||
# the same tensor format must be maintained even if some layers
|
||||
# have only linear or attention layers, for example, the mtp layer.
|
||||
@@ -2715,11 +2740,10 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# as it only support the 0-dim of kv_cache is `num_blocks`.
|
||||
# For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim
|
||||
# and rope head dim.
|
||||
if not self.model_config.use_mla:
|
||||
# for non-mla model, use FullAttentionSpec
|
||||
k_tensor_split_factor = 2.0
|
||||
v_tensor_split_factor = 2.0
|
||||
elif self.use_sparse:
|
||||
current_kv_cache_spec = layer_kv_cache_spec[layer_name]
|
||||
assert isinstance(current_kv_cache_spec, AttentionSpec)
|
||||
|
||||
if self.use_sparse:
|
||||
# for deepseek v3.2, we split the kv cache according to the corresponding ratio
|
||||
kv_cache_spec = layer_kv_cache_spec[layer_name]
|
||||
sparse_kv_cache_ratio = kv_cache_spec.sparse_kv_cache_ratio
|
||||
@@ -2728,10 +2752,11 @@ class NPUModelRunner(GPUModelRunner):
|
||||
dsa_k_tensor_split_factor = sparse_kv_cache_ratio[2]
|
||||
dsa_k_scale_tensor_split_factor = sparse_kv_cache_ratio[3]
|
||||
else:
|
||||
# for other deepseek models, use MLAAttentionSpec
|
||||
k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
|
||||
assert k_dim > 0 and v_dim > 0
|
||||
kv_head_dim_list = [
|
||||
self.model_config.hf_text_config.kv_lora_rank,
|
||||
self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
k_dim,
|
||||
v_dim,
|
||||
]
|
||||
if self.is_kv_consumer and self.vllm_config.quant_config is not None:
|
||||
k_tensor_split_factor, v_tensor_split_factor = (
|
||||
@@ -2819,20 +2844,18 @@ class NPUModelRunner(GPUModelRunner):
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
layer_kv_cache_spec = {}
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
layer_kv_cache_spec[layer_name] = group.kv_cache_spec
|
||||
layer_kv_cache_spec = self._get_layer_kv_cache_specs(kv_cache_config)
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
|
||||
current_kv_cache_spec = layer_kv_cache_spec[layer_name]
|
||||
|
||||
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
|
||||
# encounter OOM issue
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
if isinstance(current_kv_cache_spec, AttentionSpec):
|
||||
if self.use_sparse:
|
||||
if self.use_sparse_c8_indexer:
|
||||
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor, raw_dsa_k_scale_tensor = kv_cache_raw_tensors[ # type: ignore
|
||||
@@ -2862,8 +2885,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
sum_page_size_bytes = raw_k_tensor.numel() + raw_v_tensor.numel()
|
||||
assert raw_k_tensor is not None
|
||||
assert raw_v_tensor is not None
|
||||
assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
|
||||
assert sum_page_size_bytes % current_kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = sum_page_size_bytes // current_kv_cache_spec.page_size_bytes
|
||||
|
||||
# `num_blocks` is the number of blocks the model runner can use.
|
||||
# `kv_cache_config.num_blocks` is the number of blocks that
|
||||
@@ -2877,48 +2900,54 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if hasattr(attn_backend, "get_supported_kernel_block_sizes") and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_kernel_block_sizes()[0]
|
||||
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
block_size_chunk = current_kv_cache_spec.block_size // block_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk,
|
||||
block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
current_kv_cache_spec.num_kv_heads,
|
||||
current_kv_cache_spec.head_size,
|
||||
)
|
||||
if self.hybrid_with_attn_and_mamba:
|
||||
attn_tensor_page_size = int(np.prod(kv_cache_shape[1:])) * get_dtype_size(
|
||||
kv_cache_spec.dtype
|
||||
current_kv_cache_spec.dtype
|
||||
)
|
||||
conv_block_padding_size = raw_k_tensor.numel() - attn_tensor_page_size * 2
|
||||
raw_kv_tensor = raw_k_tensor[conv_block_padding_size:]
|
||||
raw_k_tensor = raw_kv_tensor[:attn_tensor_page_size]
|
||||
raw_v_tensor = raw_kv_tensor[attn_tensor_page_size:]
|
||||
else:
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
current_kv_cache_spec.block_size,
|
||||
current_kv_cache_spec.num_kv_heads,
|
||||
current_kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
if not self.model_config.use_mla:
|
||||
if not isinstance(current_kv_cache_spec, MLAAttentionSpec):
|
||||
k_shape = kv_cache_shape[1:]
|
||||
v_shape = k_shape
|
||||
if hasattr(current_kv_cache_spec, "head_size_v"):
|
||||
v_shape = (*kv_cache_shape[1:-1], current_kv_cache_spec.head_size_v)
|
||||
else:
|
||||
v_shape = k_shape
|
||||
else:
|
||||
# k_cache: nope_cache v_cache: rope_cache
|
||||
mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
|
||||
k_shape = [
|
||||
k_dim, v_dim = self._get_attention_kv_cache_dims(layer_name, current_kv_cache_spec)
|
||||
k_shape = (
|
||||
mla_num_blocks,
|
||||
mla_block_size,
|
||||
num_kv_heads,
|
||||
self.model_config.hf_text_config.kv_lora_rank,
|
||||
]
|
||||
v_shape = [
|
||||
k_dim,
|
||||
)
|
||||
v_shape = (
|
||||
mla_num_blocks,
|
||||
mla_block_size,
|
||||
num_kv_heads,
|
||||
self.model_config.hf_text_config.qk_rope_head_dim,
|
||||
]
|
||||
k_cache_dtype = v_cache_dtype = kv_cache_spec.dtype
|
||||
v_dim,
|
||||
)
|
||||
k_cache_dtype = v_cache_dtype = current_kv_cache_spec.dtype
|
||||
if self.is_kv_consumer and self.vllm_config.quant_config is not None:
|
||||
k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype(
|
||||
layer_name, kv_cache_spec.dtype, self.model_config
|
||||
layer_name, current_kv_cache_spec.dtype, self.model_config
|
||||
)
|
||||
k_cache = raw_k_tensor.view(k_cache_dtype).view(k_shape)
|
||||
v_cache = raw_v_tensor.view(v_cache_dtype).view(v_shape)
|
||||
@@ -2926,8 +2955,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
if self.use_sparse:
|
||||
dsa_k_cache_shape = (
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
current_kv_cache_spec.block_size,
|
||||
current_kv_cache_spec.num_kv_heads,
|
||||
self.model_config.hf_text_config.index_head_dim,
|
||||
)
|
||||
if self.use_sparse_c8_indexer:
|
||||
@@ -2936,8 +2965,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# dsa_k_scale
|
||||
dsa_k_scale_cache_shape = (
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
current_kv_cache_spec.block_size,
|
||||
current_kv_cache_spec.num_kv_heads,
|
||||
1,
|
||||
)
|
||||
assert raw_dsa_k_scale_tensor is not None
|
||||
@@ -2949,15 +2978,15 @@ class NPUModelRunner(GPUModelRunner):
|
||||
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache, dsa_k_scale_cache)
|
||||
else:
|
||||
# dsa_k
|
||||
dsa_k_cache = raw_dsa_k_tensor.view(kv_cache_spec.dtype).view(dsa_k_cache_shape)
|
||||
dsa_k_cache = raw_dsa_k_tensor.view(current_kv_cache_spec.dtype).view(dsa_k_cache_shape)
|
||||
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
|
||||
else:
|
||||
kv_caches[layer_name] = (k_cache, v_cache)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
elif isinstance(current_kv_cache_spec, MambaSpec):
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor is not None
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
assert raw_tensor.numel() % current_kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // current_kv_cache_spec.page_size_bytes
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
# `num_blocks` is the number of blocks the model runner can use.
|
||||
@@ -2977,7 +3006,7 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# tensor1: [(kv_padding), conv , ...]
|
||||
# tensor2: [k , ssm , ...]
|
||||
# tensor3: [v , (mamba_padding), ...]
|
||||
for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
|
||||
for shape, dtype in zip(current_kv_cache_spec.shapes, current_kv_cache_spec.dtypes):
|
||||
# normally, there is conv state and ssm state in this loop. And there is only
|
||||
# a conv state in some special models.
|
||||
target_shape = (num_blocks, *shape)
|
||||
|
||||
Reference in New Issue
Block a user