[refactor]Optimized the kvcache usage of Deepseek v3.2 (#6610)

### What this PR does / why we need it?

For DeepSeek V3.2, DSA uses `FullAttentionSpec`, which allocates 2 × the MLA page size in
bytes, and only half of that is used for the k cache in DSA.

However, the k cache does not actually need that much space, so a large
fraction of the allocated KV cache is wasted. Of the 2 × 576 elements
reserved per token, only 576 (the MLA compressed kv + RoPE key) plus 128
(the DSA indexer k) are used, so the proportion of discarded kvcache is
(576 - 128) / (576 x 2) = 0.388.
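
To make the numbers concrete, here is a minimal sketch of that arithmetic. It assumes the
standard DeepSeek V3.2 config values (`kv_lora_rank=512`, `qk_rope_head_dim=64`,
`index_head_dim=128`); these specific values are not stated in this PR and are used only
for illustration.
```python
kv_lora_rank = 512       # MLA compressed kv latent per token (assumed config value)
qk_rope_head_dim = 64    # decoupled RoPE key per token (assumed config value)
index_head_dim = 128     # DSA indexer k cache per token (assumed config value)

mla_head_size = kv_lora_rank + qk_rope_head_dim   # 576 elements per token
allocated = 2 * mla_head_size                     # FullAttentionSpec reserves k + v slots: 1152
used = mla_head_size + index_head_dim             # MLA kv + DSA indexer k actually stored: 704
print((allocated - used) / allocated)             # (576 - 128) / (576 * 2) = 0.388...
```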

Starting DeepSeek V3.2 with the same script on a single A3 server, the KV
cache capacity compares as follows:
Before refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 15,872 tokens
```
After refactoring
```
[kv_cache_utils.py:1307] GPU KV cache size: 25,984 tokens
```
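
As a rough sanity check (not part of the PR), the measured gain matches the per-token
arithmetic above up to block-size rounding, again assuming the config values used earlier:
```python
old_per_token = 2 * 576               # FullAttentionSpec allocation per token
new_per_token = 512 + 64 + 128        # MLAAttentionSpec head_size after this change
print(old_per_token / new_per_token)  # ~1.636
print(25984 / 15872)                  # ~1.637
```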

This pull request refactors the KV cache allocation for Deepseek v3.2
models that use sparse attention. It replaces the use of
`FullAttentionSpec` with `MLAAttentionSpec` and introduces a more
principled way of calculating KV cache tensor split factors based on
model configuration.

This change removes hardcoded values and correctly sizes the cache
tensors, leading to optimized memory usage and improved code
maintainability.
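
For readers skimming the diff below, the split-factor logic reduces to the following
sketch. It mirrors the new `_get_sparse_kv_cache_ratio()` helper and the split-factor
computation in the diff; the concrete values again assume the DeepSeek V3.2 config above
and are illustrative, not taken from the PR.
```python
ratio = [512, 64, 128]   # [kv_lora_rank, qk_rope_head_dim, index_head_dim]
total = sum(ratio)       # 704, used as the MLAAttentionSpec head_size
k_factor, v_factor, dsa_k_factor = (total / r for r in ratio)
# k_factor = 1.375, v_factor = 11.0, dsa_k_factor = 5.5
# A per-layer KV cache tensor of `size` bytes is then carved up proportionally:
#   k cache     -> size / k_factor      (512/704 of the tensor)
#   v cache     -> size / v_factor      ( 64/704 of the tensor)
#   DSA k cache -> size / dsa_k_factor  (128/704 of the tensor)
```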

### Does this PR introduce _any_ user-facing change?

No, this is an internal optimization and does not introduce any
user-facing changes.

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: Wang Kunpeng <1289706727@qq.com>

@@ -53,11 +53,11 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (
     AttentionSpec,
     EncoderOnlyAttentionSpec,
-    FullAttentionSpec,
     KVCacheConfig,
     KVCacheGroupSpec,
     KVCacheSpec,
     MambaSpec,
+    MLAAttentionSpec,
     UniformTypeKVCacheSpecs,
 )
 from vllm.v1.outputs import (
@@ -2438,12 +2438,11 @@ class NPUModelRunner(GPUModelRunner):
             k_tensor_split_factor = 2
             v_tensor_split_factor = 2
         elif self.use_sparse:
-            # for deepseek v3.2, DSA use FullAttentionSpec
-            # FullAttentionSpec allocate 2 * mla page size bytes,
-            # and we use half of that for k cache in DSA
-            dsa_k_cache_factor = 2
-            k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
-            v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
+            # for deepseek v3.2, we split the kv cache according to the corresponding ratio
+            sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
+            k_tensor_split_factor, v_tensor_split_factor, dsa_k_cache_factor = [  # type: ignore
+                sparse_sum_head_size / ratio for ratio in self._get_sparse_kv_cache_ratio()
+            ]
             dsa_k_cache_size = int(kv_cache_tensor.size // dsa_k_cache_factor)
         else:
             # for other deepseek models, use MLAAttentionSpec
@@ -2581,9 +2580,14 @@ class NPUModelRunner(GPUModelRunner):
             v_cache = raw_v_tensor.view(dtype).view(v_shape)
             if self.use_sparse and raw_dsa_k_tensor is not None:
-                dsa_k_cache_shape = (num_blocks, kv_cache_spec.block_size, 1, 128)
-                dsa_k_cache_size = (num_blocks) * kv_cache_spec.block_size * 128 * dtype.itemsize
-                dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(dtype).view(dsa_k_cache_shape)
+                index_head_dim = self._get_sparse_kv_cache_ratio()[-1]
+                dsa_k_cache_shape = (
+                    num_blocks,
+                    kv_cache_spec.block_size,
+                    kv_cache_spec.num_kv_heads,
+                    index_head_dim,
+                )
+                dsa_k_cache = raw_dsa_k_tensor.view(dtype).view(dsa_k_cache_shape)
                 kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
             else:
                 kv_caches[layer_name] = (k_cache, v_cache)
@@ -2832,12 +2836,13 @@ class NPUModelRunner(GPUModelRunner):
             if self.use_sparse:
                 # TODO(cmq): This is a hack way to fix deepseek kvcache when
                 # using DSA. Fix the spec in vLLM is the final way.
-                block_size = self.vllm_config.cache_config.block_size
-                kv_cache_spec[layer_name] = FullAttentionSpec(
-                    block_size=block_size,
+                sparse_sum_head_size = sum(self._get_sparse_kv_cache_ratio())
+                kv_cache_spec[layer_name] = MLAAttentionSpec(
+                    block_size=self.block_size,
                     num_kv_heads=1,
-                    head_size=attn_module.head_size,
+                    head_size=sparse_sum_head_size,
                     dtype=self.kv_cache_dtype,
+                    cache_dtype_str=self.vllm_config.cache_config.cache_dtype,
                 )
             elif spec := attn_module.get_kv_cache_spec(self.vllm_config):
                 kv_cache_spec[layer_name] = spec
@@ -2854,6 +2859,16 @@ class NPUModelRunner(GPUModelRunner):
         return kv_cache_spec
 
+    def _get_sparse_kv_cache_ratio(self) -> list[int]:
+        # TODO: If C8 is supported, we need to consider the number of bytes occupied by different dtypes
+        # when calculating the ratio, for example:
+        # [kv_lora_rank * torch.int8.itemsize, qk_rope_head_dim * torch.bfloat16.itemsize, ...]
+        return [
+            self.model_config.hf_text_config.kv_lora_rank,
+            self.model_config.hf_text_config.qk_rope_head_dim,
+            self.model_config.hf_text_config.index_head_dim,
+        ]
+
     def _check_and_update_cudagraph_mode(
         self,
         attention_backends: list[set[type[AttentionBackend]]],