[Refactor] Move AttentionSpec initialization to Attention module (#5834)
### What this PR does / why we need it?
This PR refactors the `get_kv_cache_spec` method to delegate `AttentionSpec`
creation to each attention module's own `get_kv_cache_spec()` method,
aligning with the vLLM source code structure.
**Changes:**
- Simplify `get_kv_cache_spec` in `model_runner_v1.py` and
`cpu_offload_connector.py`
- Remove manual `AttentionType` checks for `Attention` modules
- Delegate spec creation to each attention module's `get_kv_cache_spec`
method directly (see the sketch below)
- Let `MambaBase` layers use their own `get_kv_cache_spec` method
- Keep the `use_sparse` hack for `MLAAttention` (DeepSeek DSA mode) as
Ascend-specific handling
This change follows RFC #5463 item 12: move AttentionSpec to Attention
module.
- Fixes #5463 (item 12)
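
For context, the refactored dispatch reduces to roughly the following shape. This is a condensed, illustrative sketch rather than the verbatim patch; the `Attention`/`MLAAttention` import paths are assumptions, and the real code lives in `cpu_offload_connector.py` and `model_runner_v1.py` (see the diff below).

```python
# Illustrative sketch only: condensed form of the refactored dispatch.
from vllm.attention.layer import Attention, MLAAttention  # import path assumed
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec


def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
    use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
    if vllm_config.cache_config.cache_dtype == "auto":
        kv_cache_dtype = vllm_config.model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            vllm_config.cache_config.cache_dtype]

    kv_cache_spec: dict[str, KVCacheSpec] = {}
    # In the patched code the AttentionLayerBase lookup also returns MambaBase
    # layers, so a single loop covers attention and Mamba layers alike.
    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
    for name, module in attn_layers.items():
        if isinstance(module, MLAAttention) and use_sparse:
            # Ascend-specific DSA hack kept from the old code: sparse MLA
            # (DeepSeek DSA) still gets a hand-built FullAttentionSpec.
            kv_cache_spec[name] = FullAttentionSpec(
                block_size=vllm_config.cache_config.block_size,
                num_kv_heads=1,
                head_size=module.head_size,
                dtype=kv_cache_dtype)
        elif spec := module.get_kv_cache_spec(vllm_config):
            # Everything else delegates to the module's own spec builder.
            kv_cache_spec[name] = spec
    return kv_cache_spec
```

The actual patch additionally collects `MambaBase` layers into a separate second pass to preserve the ordering expected by `acl_graph.py`'s `_update_attn_fia_params` (see the `NOTE` comments in the diff); the sketch only shows the delegation idea.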
### Does this PR introduce _any_ user-facing change?
No. This is an internal refactoring that simplifies code structure
without changing any external behavior.
### How was this patch tested?
- Syntax validation passed via `python -m py_compile`
- CI tests will verify the changes work correctly with existing test
cases
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: lico67373 <918688502@qq.com>
@@ -20,8 +20,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
                                        MambaSpec, MLAAttentionSpec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec

from vllm_ascend.distributed.kv_transfer.kv_pool.cpu_offload.metadata import (
    MetadataServer, MetadataServerProc, MLAConfig)
@@ -461,80 +460,45 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
    if has_ec_transfer() and get_ec_transfer().is_producer:
        return {}

    block_size = vllm_config.cache_config.block_size
    use_mla = vllm_config.model_config.use_mla
    use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
    if vllm_config.cache_config.cache_dtype == "auto":
        kv_cache_dtype = vllm_config.model_config.dtype
    else:
        kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            vllm_config.cache_config.cache_dtype]

    kv_cache_spec: dict[str, KVCacheSpec] = {}
    attn_layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase)
    # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
    # ordering expected by acl_graph.py's _update_attn_fia_params.
    mamba_layers: dict[str, MambaBase] = {}
    for layer_name, attn_module in attn_layers.items():
        if isinstance(attn_module, Attention):
            # TODO: Support other attention modules, e.g., cross-attention
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            if attn_module.attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=kv_cache_dtype)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")
            if spec := attn_module.get_kv_cache_spec(vllm_config):
                kv_cache_spec[layer_name] = spec

        elif isinstance(attn_module, MLAAttention):
            if use_mla and not use_sparse:
                kv_cache_spec[layer_name] = MLAAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=kv_cache_dtype,
                    cache_dtype_str=vllm_config.cache_config.cache_dtype)
            else:
            if use_sparse:
                # TODO(cmq): This is a hack way to fix deepseek kvcache when
                # using DSA. Fix the spec in vLLM is a finnal way.
                # using DSA. Fix the spec in vLLM is the final way.
                block_size = vllm_config.cache_config.block_size
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=1,
                    head_size=attn_module.head_size,
                    dtype=kv_cache_dtype)
            elif spec := attn_module.get_kv_cache_spec(vllm_config):
                kv_cache_spec[layer_name] = spec

        elif isinstance(attn_module, MambaBase):
            mamba_layers[layer_name] = attn_module

    mamba_layers = get_layers_from_vllm_config(vllm_config, MambaBase)
    if len(mamba_layers) > 0:
        if (vllm_config.speculative_config is not None
                and vllm_config.model_config.hf_config.model_type
                not in ["qwen3_next"]):
            raise NotImplementedError(
                "Mamba with speculative decoding is not supported yet.")
        if vllm_config.cache_config.enable_prefix_caching:
            raise NotImplementedError(
                "Prefix caching is not supported for Mamba yet.")
        max_model_len = vllm_config.model_config.max_model_len

        page_size_padded = (vllm_config.cache_config.mamba_page_size_padded)

        # Set block_size to max_model_len, so that mamba model will always
        # have only one block in the KV cache.
        for layer_name, mamba_module in mamba_layers.items():
            kv_cache_spec[layer_name] = MambaSpec(
                shapes=mamba_module.get_state_shape(),
                dtypes=mamba_module.get_state_dtype(),
                block_size=max_model_len,
                page_size_padded=page_size_padded,
                mamba_type=mamba_module.mamba_type,
                num_speculative_blocks=(
                    vllm_config.speculative_config.num_speculative_tokens
                    if vllm_config.speculative_config else 0),
            )
            if spec := mamba_module.get_kv_cache_spec(vllm_config):
                kv_cache_spec[layer_name] = spec

    return kv_cache_spec
@@ -53,12 +53,11 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, CrossAttentionSpec,
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        EncoderOnlyAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheSpec,
                                        MambaSpec, MLAAttentionSpec,
                                        UniformTypeKVCacheSpecs)
                                        MambaSpec, UniformTypeKVCacheSpecs)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             LogprobsLists, LogprobsTensors, ModelRunnerOutput,
                             SamplerOutput,
@@ -2886,11 +2885,12 @@ class NPUModelRunner(GPUModelRunner):
        if has_ec_transfer() and get_ec_transfer().is_producer:
            return {}

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                  AttentionLayerBase)
        # NOTE: Must process Attention/MLAAttention before MambaBase to maintain
        # ordering expected by acl_graph.py's _update_attn_fia_params.
        mamba_layers: dict[str, MambaBase] = {}
        for layer_name, attn_module in attn_layers.items():
            if isinstance(attn_module, Attention):
                if (kv_tgt_layer :=
@@ -2905,74 +2905,32 @@ class NPUModelRunner(GPUModelRunner):
                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                    continue

                # TODO: Support other attention modules, e.g., cross-attention
                # TODO(lucas): move the attention specs into the model layers like
                # the attention backends
                if attn_module.attn_type == AttentionType.DECODER:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)
                elif attn_module.attn_type in (AttentionType.ENCODER,
                                               AttentionType.ENCODER_ONLY):
                    # encoder-only attention does not need KV cache.
                    continue
                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                    kv_cache_spec[layer_name] = CrossAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)
                else:
                    raise ValueError(
                        f"Unknown attention type: {attn_module.attn_type}")
                if spec := attn_module.get_kv_cache_spec(self.vllm_config):
                    kv_cache_spec[layer_name] = spec

            elif isinstance(attn_module, MLAAttention):
                if use_mla and not self.use_sparse:
                    kv_cache_spec[layer_name] = MLAAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=1,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        cache_dtype_str=self.cache_config.cache_dtype)
                else:
                if self.use_sparse:
                    # TODO(cmq): This is a hack way to fix deepseek kvcache when
                    # using DSA. Fix the spec in vLLM is a finnal way.
                    # using DSA. Fix the spec in vLLM is the final way.
                    block_size = self.vllm_config.cache_config.block_size
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=1,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype)
                elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
                    kv_cache_spec[layer_name] = spec

            elif isinstance(attn_module, MambaBase):
                mamba_layers[layer_name] = attn_module

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if len(mamba_layers) > 0:
            if (self.vllm_config.speculative_config is not None
                    and self.vllm_config.model_config.hf_text_config.model_type
                    not in ["qwen3_next"]):
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtypes=mamba_module.get_state_dtype(),
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type,
                    num_speculative_blocks=(
                        self.speculative_config.num_speculative_tokens
                        if self.speculative_config else 0),
                )
                if spec := mamba_module.get_kv_cache_spec(self.vllm_config):
                    kv_cache_spec[layer_name] = spec

        return kv_cache_spec