[ModelRunner][Refactor] Refactor kv cache tensor initialization logic (#3106)

### What this PR does / why we need it?
Refactor kv cache tensor initialization logic. 
1. Unify the kvcache tensor initialization logic of deepseek and normal
models
2. split `initialize_kv_cache_tensors` into `_allocate_kv_cache_tensors`
and `_reshape_kv_cache_tensors`, following the GPU model runner in vLLM

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with existing test.
1. prefill disaggregation scenario
2. deepseek + aclgraph/eager mode
3. qwen3 next


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-11-04 17:26:54 +08:00
committed by GitHub
parent bedf223771
commit 5fed166a99

View File

@@ -3138,15 +3138,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
])
self.may_reinitialize_input_batch(kv_cache_config)
if self.use_sparse:
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
kv_cache_config)
elif self.model_config.is_deepseek_mla:
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
kv_cache_config)
else:
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
@@ -3158,197 +3150,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
def initialize_kv_cache_tensors_deepseek_sfa(
        self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
    """Allocate and bind KV cache tensors for DeepSeek sparse attention (SFA).

    For every attention layer three buffers are created: a "nope" cache
    (head_size minus the rope dim), a "rope" cache (the rotary part), and an
    extra k cache used by sparse attention. When ``kv_transfer_config`` is
    set (prefill-disaggregation scenario), buffers are over-allocated and
    manually aligned to a 2 MB boundary so llmdatadist can register them.

    Args:
        kv_cache_config: The KV cache configuration.

    Returns:
        dict[str, torch.Tensor]: Mapping from layer name to its
        ``(nope_cache, rope_cache, k_cache)`` tuple.
    """
    # Per-layer raw buffer sizes, keyed by layer name.
    kv_cache_sizes = {}
    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
        assert len(kv_cache_tensor.shared_by) == 1, (
            "KV cache tensor shared by multiple layers is not supported in "
            "NPU.")
        kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

    kv_caches: Dict[str, torch.Tensor] = {}
    for group in self._kv_cache_spec_attn_group_iterator():
        kv_cache_spec = group.kv_cache_spec
        attn_backend = group.backend
        for layer_name in group.layer_names:
            if layer_name in self.runner_only_attn_layers:
                continue
            tensor_size = kv_cache_sizes[layer_name]
            num_blocks = tensor_size // kv_cache_spec.page_size_bytes
            if self.vllm_config.additional_config.get(
                    "kv_cache_dtype", None) == 'int8':
                kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            elif hasattr(
                    attn_backend, "get_supported_block_size"
            ) and not self.model_config.is_deepseek_mla and not self.use_sparse:
                # NOTE(review): this branch looks unreachable — this method is
                # only called on the sparse-attention path where
                # self.use_sparse is true. Confirm before removing.
                block_size = attn_backend.get_supported_block_size()[0]
                block_size_chunk = kv_cache_spec.block_size // block_size
                kv_cache_shape = attn_backend.get_kv_cache_shape(
                    num_blocks * block_size_chunk, block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            else:
                # NOTE(review): uses self.attn_backend here rather than the
                # group's attn_backend — verify this is intentional.
                kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            dtype = kv_cache_spec.dtype
            # llmdatadist requires memory addresses aligned to 2 MB.
            alignment = 2 * 1024 * 1024
            num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
            nope_dim = head_size - rope_dim
            nope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                nope_dim)
            rope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                rope_dim)
            #### k cache
            # TODO(zzzzwwjj): wait transformers add these params
            k_cache_shape = (num_blocks, block_size, 1, 128)
            if self.vllm_config.kv_transfer_config is None:
                # For no disaggregate pd scenario, allocate kv cache in normal way
                rope_cache = torch.zeros(rope_cache_shape,
                                         dtype=dtype,
                                         device=self.device)
                nope_cache = torch.zeros(nope_cache_shape,
                                         dtype=dtype,
                                         device=self.device)
                rope_cache = self._convert_torch_format(rope_cache)
                nope_cache = self._convert_torch_format(nope_cache)
                #### k cache
                k_cache = torch.zeros(k_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
                k_cache = self._convert_torch_format(k_cache)
            else:
                # In order to transfer kv cache through the register_memory api
                # from llmdatadist, the memory address should be aligned by 2M.
                # In most cases torch_npu can allocate 2M aligned memory, but
                # we found there are also some exceptions during test, so we
                # manually align that memory here; this part of code may
                # consume 2M * 2 * elem_size memory every layer.
                nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
                nope_allocate_shape_alignment = nope_allocate_shape + alignment
                rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
                rope_allocate_shape_alignment = rope_allocate_shape + alignment
                nope_cache = torch.zeros(nope_allocate_shape_alignment,
                                         dtype=dtype,
                                         device=self.device)
                rope_cache = torch.zeros(rope_allocate_shape_alignment,
                                         dtype=dtype,
                                         device=self.device)
                #### k cache
                # TODO(zzzzwwjj): wait transformers add these params
                k_allocate_shape = num_blocks * block_size * 1 * 128
                k_allocate_shape_alignment = k_allocate_shape + alignment
                k_cache = torch.zeros(k_allocate_shape_alignment,
                                      dtype=dtype,
                                      device=self.device)
                nope_cache = self._align_memory(
                    nope_cache,
                    alignment)[:nope_allocate_shape].view(nope_cache_shape)
                rope_cache = self._align_memory(
                    rope_cache,
                    alignment)[:rope_allocate_shape].view(rope_cache_shape)
                k_cache = self._align_memory(
                    k_cache,
                    alignment)[:k_allocate_shape].view(k_cache_shape)
            kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)

    bind_kv_cache(kv_caches,
                  self.compilation_config.static_forward_context,
                  self.kv_caches)
    return kv_caches
def initialize_kv_cache_tensors_deepseek_mla(
        self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
    """Allocate and bind KV cache tensors for DeepSeek MLA models.

    For every attention layer two buffers are created: a "nope" cache
    (head_size minus the rope dim) and a "rope" cache (the rotary part).
    When ``kv_transfer_config`` is set (prefill-disaggregation scenario),
    buffers are over-allocated and manually aligned to a 2 MB boundary so
    llmdatadist can register them.

    Args:
        kv_cache_config: The KV cache configuration.

    Returns:
        dict[str, torch.Tensor]: Mapping from layer name to its
        ``(nope_cache, rope_cache)`` tuple.
    """
    # Per-layer raw buffer sizes, keyed by layer name.
    kv_cache_sizes = {}
    for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
        assert len(kv_cache_tensor.shared_by) == 1, (
            "KV cache tensor shared by multiple layers is not supported in "
            "NPU.")
        kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

    kv_caches: Dict[str, torch.Tensor] = {}
    for group in self._kv_cache_spec_attn_group_iterator():
        kv_cache_spec = group.kv_cache_spec
        attn_backend = group.backend
        for layer_name in group.layer_names:
            if layer_name in self.runner_only_attn_layers:
                continue
            tensor_size = kv_cache_sizes[layer_name]
            num_blocks = tensor_size // kv_cache_spec.page_size_bytes
            if self.vllm_config.additional_config.get(
                    "kv_cache_dtype", None) == 'int8':
                kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            elif hasattr(attn_backend, "get_supported_block_size"
                         ) and not self.model_config.is_deepseek_mla:
                # NOTE(review): this branch looks unreachable — this method is
                # only called on the MLA path where is_deepseek_mla is true.
                # Confirm before removing.
                block_size = attn_backend.get_supported_block_size()[0]
                block_size_chunk = kv_cache_spec.block_size // block_size
                kv_cache_shape = attn_backend.get_kv_cache_shape(
                    num_blocks * block_size_chunk, block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            else:
                # NOTE(review): uses self.attn_backend here rather than the
                # group's attn_backend — verify this is intentional.
                kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                    num_blocks, kv_cache_spec.block_size,
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            dtype = kv_cache_spec.dtype
            # llmdatadist requires memory addresses aligned to 2 MB.
            alignment = 2 * 1024 * 1024
            num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
            rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
            nope_dim = head_size - rope_dim
            nope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                nope_dim)
            rope_cache_shape = (num_blocks, block_size, num_kv_heads,
                                rope_dim)
            if self.vllm_config.kv_transfer_config is None:
                # For no disaggregate pd scenario, allocate kv cache in normal way
                rope_cache = torch.zeros(rope_cache_shape,
                                         dtype=dtype,
                                         device=self.device)
                nope_cache = torch.zeros(nope_cache_shape,
                                         dtype=dtype,
                                         device=self.device)
                rope_cache = self._convert_torch_format(rope_cache)
                nope_cache = self._convert_torch_format(nope_cache)
            else:
                # In order to transfer kv cache through the register_memory api
                # from llmdatadist, the memory address should be aligned by 2M.
                # In most cases torch_npu can allocate 2M aligned memory, but
                # we found there are also some exceptions during test, so we
                # manually align that memory here; this part of code may
                # consume 2M * 2 * elem_size memory every layer.
                nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
                nope_allocate_shape_alignment = nope_allocate_shape + alignment
                rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
                rope_allocate_shape_alignment = rope_allocate_shape + alignment
                nope_cache = torch.zeros(nope_allocate_shape_alignment,
                                         dtype=dtype,
                                         device=self.device)
                rope_cache = torch.zeros(rope_allocate_shape_alignment,
                                         dtype=dtype,
                                         device=self.device)
                nope_cache = self._align_memory(
                    nope_cache,
                    alignment)[:nope_allocate_shape].view(nope_cache_shape)
                rope_cache = self._align_memory(
                    rope_cache,
                    alignment)[:rope_allocate_shape].view(rope_cache_shape)
            kv_caches[layer_name] = (nope_cache, rope_cache)

    bind_kv_cache(kv_caches,
                  self.compilation_config.static_forward_context,
                  self.kv_caches)
    return kv_caches
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
@@ -3360,6 +3161,34 @@ class NPUModelRunner(LoRAModelRunnerMixin):
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
# Initialize the memory buffer for KV cache
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
# Change the memory buffer to the desired shape
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
kv_cache_raw_tensors)
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
Initializes the KV cache buffer with the correct size. The buffer needs
to be reshaped to the desired shape before being used by the models.
NOTE: To support prefill disaggregation, we need to split kvcache tensor into
k_cahce and v cache, and the addr of both are aligned by 2M
Args:
kv_cache_config: The KV cache config
Returns:
dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
dict[str, tuple(torch.Tensor, torch.Tensor)] A map between layer names
to their corresponding memory buffer for K cache and V cache.
"""
# init kv cache tensors
kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
Optional[torch.Tensor]]] = {}
@@ -3383,39 +3212,88 @@ class NPUModelRunner(LoRAModelRunnerMixin):
device=self.device)
tensor = self._align_memory(
tensor, alignment)[:kv_cache_tensor.size]
for layer_name_inner in kv_cache_tensor.shared_by:
# shared the kvcache between the linear_attn specs in the same group
# shared the kvcache between the self_attn specs in the same group
if "linear_attn" in layer_name_inner:
kv_cache_raw_tensors[layer_name_inner] = tensor
elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
):
# NOTE: We need to init k cache tensor (nope cache tensor in mla) and
# v cache tensor (rope cache tensor in mla) separately to support llmdatadist,
# as it only support the 0-dim of kv_cache is `num_blocks`.
# For deepseek mla, we need to spilt cache tensor accrodding to the nope head dim
# and rope head dim.
if self.model_config.is_deepseek_mla:
head_size = self.model_config.hf_text_config.qk_rope_head_dim + \
self.model_config.hf_text_config.kv_lora_rank
dsa_k_cache_factor = None
dsa_k_cache_size = None
if not self.model_config.is_deepseek_mla:
# for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2
v_tensor_split_factor = 2
elif self.use_sparse:
# for deepseek v3.2, DSA use FullAttentionSpec
# FullAttentionSpec allocate 2 * mla page size bytes,
# and we use half of that for k cache in DSA
dsa_k_cache_factor = 2
k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
dsa_k_cache_size = int(kv_cache_tensor.size //
dsa_k_cache_factor)
else:
# for other deepseek models, use MLAAttentionSpec
k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim
k_tensor_size = int(kv_cache_tensor.size //
k_tensor_split_factor)
v_tensor_size = int(kv_cache_tensor.size //
v_tensor_split_factor)
# for other attentions, e.g., self_attn, sliding window attn
if self.vllm_config.kv_transfer_config is None:
k_tensor = torch.zeros(kv_cache_tensor.size // 2,
k_tensor = torch.zeros(k_tensor_size,
dtype=torch.int8,
device=self.device)
v_tensor = torch.zeros(kv_cache_tensor.size // 2,
v_tensor = torch.zeros(v_tensor_size,
dtype=torch.int8,
device=self.device)
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None:
dsa_k_cache_tensor = torch.zeros(
dsa_k_cache_size,
dtype=torch.int8,
device=self.device)
else:
cache_size = kv_cache_tensor.size // 2
cache_size_aligned = kv_cache_tensor.size // 2 + alignment
k_tensor = torch.zeros(cache_size_aligned,
k_tensor = torch.zeros(k_tensor_size + alignment,
dtype=torch.int8,
device=self.device)
v_tensor = torch.zeros(cache_size_aligned,
v_tensor = torch.zeros(v_tensor_size + alignment,
dtype=torch.int8,
device=self.device)
k_tensor = self._align_memory(k_tensor,
alignment)[:cache_size]
v_tensor = self._align_memory(v_tensor,
alignment)[:cache_size]
k_tensor = self._align_memory(
k_tensor, alignment)[:k_tensor_size]
v_tensor = self._align_memory(
v_tensor, alignment)[:v_tensor_size]
#### k cache: for deepseek sparse attention
if dsa_k_cache_factor is not None and dsa_k_cache_size is not None:
dsa_k_cache_tensor = torch.zeros(
dsa_k_cache_size + alignment,
dtype=torch.int8,
device=self.device)
dsa_k_cache_tensor = self._align_memory(
dsa_k_cache_tensor,
alignment)[:dsa_k_cache_size]
for layer_name_inner in kv_cache_tensor.shared_by:
# shared the kvcache between the self_attn specs in the same group
if ("attn" in layer_name_inner
and "linear_attn" not in layer_name_inner):
kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
v_tensor)
kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) if \
not self.use_sparse else (k_tensor, v_tensor, dsa_k_cache_tensor)
layer_names = set()
for group in kv_cache_config.kv_cache_groups:
@@ -3426,6 +3304,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
assert layer_names == set(kv_cache_raw_tensors.keys(
)), "Some layers are not correctly initialized"
return kv_cache_raw_tensors
def _reshape_kv_cache_tensors(
self,
kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
Reshape the KV cache tensors to the desired shape and dtype.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer, with
correct size but uninitialized shape.
Returns:
Dict[str, torch.Tensor]: A map between layer names to their
corresponding memory buffer for KV cache.
"""
kv_caches: Dict[str, torch.Tensor] = {}
for group in self._kv_cache_spec_attn_group_iterator():
kv_cache_spec = group.kv_cache_spec
@@ -3437,14 +3333,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
# encounter OOM issue
if isinstance(kv_cache_spec, FullAttentionSpec):
raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
raw_dsa_k_tensor = None
if self.use_sparse:
raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
assert raw_dsa_k_tensor is not None
sum_page_size_bytes = raw_k_tensor.numel(
) + raw_v_tensor.numel() + raw_dsa_k_tensor.numel()
else:
raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore
layer_name]
sum_page_size_bytes = raw_k_tensor.numel(
) + raw_v_tensor.numel()
assert raw_k_tensor is not None
assert raw_v_tensor is not None
assert (raw_k_tensor.numel() + raw_v_tensor.numel()
) % kv_cache_spec.page_size_bytes == 0
num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel()
) // kv_cache_spec.page_size_bytes
assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
# `num_blocks` is the number of blocks the model runner can use.
# `kv_cache_config.num_blocks` is the number of blocks that
@@ -3476,11 +3380,35 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:])
if not self.model_config.is_deepseek_mla:
k_shape = kv_cache_shape[1:]
v_shape = k_shape
else:
# k_cache: nope_cache v_cache: rope_cache
mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
k_shape = [
mla_num_blocks, mla_block_size, num_kv_heads,
self.model_config.hf_text_config.kv_lora_rank
]
v_shape = [
mla_num_blocks, mla_block_size, num_kv_heads,
self.model_config.hf_text_config.qk_rope_head_dim
]
k_cache = raw_k_tensor.view(dtype).view(k_shape)
k_cache = self._convert_torch_format(k_cache)
v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:])
v_cache = raw_v_tensor.view(dtype).view(v_shape)
v_cache = self._convert_torch_format(v_cache)
kv_caches[layer_name] = (k_cache, v_cache)
if self.use_sparse and raw_dsa_k_tensor is not None:
dsa_k_cache_shape = (num_blocks,
kv_cache_spec.block_size, 1, 128)
dsa_k_cache_size = (
num_blocks
) * kv_cache_spec.block_size * 128 * dtype.itemsize
dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view(
dtype).view(dsa_k_cache_shape)
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
else:
kv_caches[layer_name] = (k_cache, v_cache)
elif isinstance(kv_cache_spec, MambaSpec):
raw_tensor = kv_cache_raw_tensors[layer_name]
assert raw_tensor is not None
@@ -3521,10 +3449,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
raise ValueError("Unknown KV cache spec type.")
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def may_reinitialize_input_batch(self,