[KVCache] Refactor KVCache as page_size_bytes is ineffective (#3438)

### What this PR does / why we need it?
Refactor KVCache as page_size_bytes is ineffective.

1. Currently `AttentionSpec` is patched, but at runtime vLLM still uses
its own `page_size_bytes`, so the patch has no actual effect. This PR
therefore removes the patch on `AttentionSpec`; the final fix will be
made in vLLM.
2. Use `MLAAttentionSpec` instead of `FullAttentionSpec` to reduce the
`page_size_bytes` of the spec, so that `num_blocks` in the spec can double.

### How was this patch tested?
Tests pass with Qwen3-Next and DeepSeek-V3.2-Exp.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-10-14 21:28:41 +08:00
committed by GitHub
parent c55d99d13e
commit 223cc34085
6 changed files with 38 additions and 131 deletions

View File

@@ -18,8 +18,10 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.utils import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
MLAAttentionSpec)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
MetadataServer, MetadataServerProc, MLAConfig)
@@ -434,18 +436,30 @@ def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
forward_ctx = vllm_config.compilation_config.static_forward_context
block_size = vllm_config.cache_config.block_size
use_mla = vllm_config.model_config.use_mla
ascend_config = get_ascend_config()
use_sfa = ascend_config.use_sfa
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
if isinstance(attn_module, FusedMoE):
continue
assert isinstance(attn_module, Attention)
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
use_mla=use_mla)
if use_mla and not use_sfa:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hacky way to fix the DeepSeek KV cache when
# using DSA. Fixing the spec in vLLM is the final solution.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=attn_module.dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
continue

View File

@@ -19,4 +19,3 @@ import vllm_ascend.patch.platform.patch_common.patch_config # noqa
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa

View File

@@ -6,8 +6,6 @@ from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm_ascend.ascend_config import get_ascend_config
@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
@@ -24,7 +22,6 @@ def verify_and_update_config(cls, vllm_config) -> None:
logger = init_logger(__name__)
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
ascend_config = get_ascend_config()
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
@@ -40,8 +37,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes
dtype=kv_cache_dtype).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,

View File

@@ -22,7 +22,6 @@ if HAS_TRITON:
# isort: off
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa

View File

@@ -1,110 +0,0 @@
from dataclasses import dataclass, fields
from typing import Optional
import torch
import vllm
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.utils import cdiv, get_dtype_size
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
spec_manager_map)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """KV cache spec for attention layers, with MLA and SFA flags.

    ``page_size_bytes`` counts one cached vector per token when ``use_mla``
    is set (two otherwise) and adds an extra per-block term when ``use_sfa``
    is set.
    """
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool
    use_sfa: bool

    @property
    def page_size_bytes(self) -> int:
        """Return the size in bytes of one KV cache page."""
        elem_bytes = get_dtype_size(self.dtype)
        # For MLA only a single latent vector is stored; otherwise two
        # vectors are stored per token.
        vectors_per_token = 1 if self.use_mla else 2
        kv_bytes = (vectors_per_token * self.block_size * self.num_kv_heads *
                    self.head_size * elem_bytes)
        # NOTE(review): 128 is presumably the per-token width of the extra
        # SFA buffer -- confirm against the SFA attention backend.
        sfa_bytes = 128 * self.block_size * elem_bytes if self.use_sfa else 0
        return kv_bytes + sfa_bytes
# Monkey-patch: rebind the ``AttentionSpec`` attribute of
# ``vllm.v1.kv_cache_interface`` to the class defined in this module.
vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
@dataclass(frozen=True)
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
    """Full-attention KV cache spec with optional window / chunk sizes.

    When the hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, the sliding window
    layers are regarded as full attention in the KV cache manager (blocks
    are allocated for all tokens), while they are computed as sliding window
    attention in the model runner. In that case this spec is used and the
    sliding window size is recorded. ``None`` (the default) means sliding
    window attention is not used.
    """
    sliding_window: Optional[int] = None
    attention_chunk_size: Optional[int] = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return the bytes needed to cache ``max_model_len`` tokens.

        With decode context parallelism (``dcp_world_size > 1``) each rank
        only stores ``ceil(max_model_len / dcp_world_size)`` tokens locally.
        """
        token_budget = vllm_config.model_config.max_model_len
        dcp_world_size = \
            vllm_config.parallel_config.decode_context_parallel_size
        if dcp_world_size > 1:
            token_budget = cdiv(token_budget, dcp_world_size)
        pages_needed = cdiv(token_budget, self.block_size)
        return pages_needed * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        """Collapse a set of window sizes into one value.

        Returns ``None`` for an empty set and the sole element for a
        one-element set (the element is popped from the given set).

        Raises:
            ValueError: if the set holds more than one distinct size.
        """
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")
        if not window_sizes:
            return None
        return window_sizes.pop()

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """Merge a list of FullAttentionSpec objects into a single spec.

        The merged spec copies the shared fields from ``specs[0]``; every
        spec must agree with it on all ``AttentionSpec`` fields, and the
        result may carry a sliding window or an attention chunk size, but
        not both.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")

        # Collect the distinct non-None sizes declared across the group.
        window_candidates = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        chunk_candidates = {
            spec.attention_chunk_size
            for spec in specs if spec.attention_chunk_size is not None
        }

        template = specs[0]
        merged_spec = cls(
            block_size=template.block_size,
            num_kv_heads=template.num_kv_heads,
            head_size=template.head_size,
            dtype=template.dtype,
            use_mla=template.use_mla,
            use_sfa=template.use_sfa,
            sliding_window=cls.merge_window_sizes(window_candidates),
            attention_chunk_size=cls.merge_window_sizes(chunk_candidates),
        )

        # Every spec must match the merged spec on the base attention fields.
        for spec in specs:
            for base_field in fields(AttentionSpec):
                field_name = base_field.name
                assert getattr(spec, field_name) == getattr(
                    merged_spec, field_name), (
                        "All attention layers in the same KV cache group must have "
                        "the same attention spec.")

        has_window = merged_spec.sliding_window is not None
        has_chunk = merged_spec.attention_chunk_size is not None
        assert not (has_window and has_chunk), (
            "Model with both sliding window layers and chunked local attention "
            "layers is not supported.")
        return merged_spec
# Map the Ascend spec type to FullAttentionManager in vLLM's
# spec_manager_map.
spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})
# Monkey-patch: rebind ``vllm.v1.kv_cache_interface.FullAttentionSpec`` to
# the Ascend subclass.
vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec

View File

@@ -80,6 +80,7 @@ from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, MambaSpec,
MLAAttentionSpec,
UniformTypeKVCacheSpecs)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
@@ -3220,13 +3221,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla,
use_sfa=use_sfa)
if use_mla and not use_sfa:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=self.cache_config.cache_dtype)
else:
# TODO(cmq): This is a hacky way to fix the DeepSeek KV cache when
# using DSA. Fixing the spec in vLLM is the final solution.
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.