[Refactor]: refactor 310P KV cache allocator to align with main branch (#6270)

### What this PR does / why we need it?
Refactors the 310P KV cache allocator to align it with the main branch.

vLLM version: v0.14.0
vLLM main: https://github.com/vllm-project/vllm-ascend/pull/6270
Verified with a Qwen2.5-7B E2E test.

---------

Signed-off-by: pu-zhe <puzhe1@h-partners.com>
Signed-off-by: pu-zhe <zpuaa@outlook.com>
Co-authored-by: pu-zhe <puzhe1@h-partners.com>
Commit 57fd6e4bd9 (parent 5e34c70ffc), authored by pu-zhe, committed via GitHub on 2026-01-27 16:26:48 +08:00.


```diff
@@ -17,12 +17,10 @@
 from __future__ import annotations

-from typing import Any
-
 import torch
 import torch_npu
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
-from vllm.v1.worker.utils import bind_kv_cache
+from vllm.logger import logger
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec

 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
```
```diff
@@ -33,154 +31,157 @@ class NPUModelRunner310(NPUModelRunner):
         super().__init__(*args, **kwargs)
         self._acl_format = ACL_FORMAT_FRACTAL_NZ

-    def initialize_kv_cache_tensors(
-        self,
-        kv_cache_config: KVCacheConfig,
-    ) -> dict[str, Any]:
+    def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
-        Initialize KV cache tensors for 310P.
+        Initialize the memory buffer for KV cache.
-        1) allocate buffers
-        2) reshape / transform to the final layout
-        3) optional cross-layer sharing
-        4) bind buffers to the static forward context
+        Args:
+            kv_cache_config: The KV cache config
+        Returns:
+            Dict[str, torch.Tensor]: A map between layer names to their
+            corresponding memory buffer for KV cache.
         """
         # 310P limitation: KV transfer is not supported.
         if self.vllm_config.kv_transfer_config is not None:
             raise ValueError("KV cache transfer is not supported for 310P.")
-        kv_cache_raw_tensors = self._allocate_kv_cache_tensors_310p(kv_cache_config)
-        kv_caches = self._reshape_kv_cache_tensors_310p(kv_cache_config, kv_cache_raw_tensors)
-        # Keep the same cross-layer KV cache sharing logic as the main branch.
-        # For 310P, this is expected to be empty in most cases, but keeping it
-        # makes the code path consistent and easier to reason about.
+        if self.use_sparse:
+            raise ValueError("Deepseek Sparse Attention is not supported for 310P.")
+        if self.model_config.use_mla:
+            raise ValueError("MLAAttention is not supported for 310P.")
+        # Initialize the memory size for KV cache
+        kv_cache_size = self._calculate_kv_cache_tensors_size(kv_cache_config)
+        # Allocate and reshape KV cache tensors
+        kv_caches = self._allocate_kv_cache_and_reshape_tensors(kv_cache_config, kv_cache_size)
+        # Set up cross-layer KV cache sharing
         for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
             logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
             kv_caches[layer_name] = kv_caches[target_layer_name]
-        # 310P devices do not support the "longcat_flash" special case here, so this is always "1".
-        bind_kv_cache(
-            kv_caches,
-            self.compilation_config.static_forward_context,
-            self.kv_caches,
-            1,
-        )
+        from vllm.v1.worker.utils import bind_kv_cache
+        bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches)
         return kv_caches
```
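Since the sharing loop above only re-points dictionary entries, a reused layer costs no extra device memory. A minimal standalone sketch of the aliasing behaviour, with hypothetical layer names and plain CPU tensors:

```python
import torch

# Hypothetical layer names; shared_kv_cache_layers maps a layer that
# reuses a cache to the layer that owns it.
kv_caches = {"model.layers.0.attn": torch.zeros(4, 2)}
shared_kv_cache_layers = {"model.layers.1.attn": "model.layers.0.attn"}

for layer_name, target_layer_name in shared_kv_cache_layers.items():
    # No new memory is allocated: both names now refer to one tensor.
    kv_caches[layer_name] = kv_caches[target_layer_name]

assert (kv_caches["model.layers.1.attn"].data_ptr()
        == kv_caches["model.layers.0.attn"].data_ptr())
```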
```diff
-    def _allocate_kv_cache_tensors_310p(
-        self,
-        kv_cache_config: KVCacheConfig,
-    ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
+    def _calculate_kv_cache_tensors_size(self, kv_cache_config: KVCacheConfig) -> dict[str, int]:
         """
-        Allocate KV cache buffers for each attention layer.
+        Initializes the KV cache size. The buffer needs to be reshaped to the desired shape before being used by
+        the models.
-        Unlike the non-310p path, 310P uses torch.zeros directly with the final dtype,
-        and defers layout casting (ACL format) to the reshape step.
+        Args:
+            kv_cache_config: The KV cache config
+        Returns:
+            dict[str, int]: A map between layer names to their
+            corresponding memory buffer size.
         """
-        # Build a mapping: layer_name -> tensor_size(bytes).
+        # init kv cache tensors
         kv_cache_sizes: dict[str, int] = {}
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            # 310P limitation: a KV cache tensor must not be shared by multiple layers.
-            assert len(kv_cache_tensor.shared_by) == 1, (
-                "KV cache tensor shared by multiple layers is not supported in 310P."
-            )
-            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+            # TODO: REFACTOR ME to sharing hybrid cache
+            for idx in range(len(kv_cache_tensor.shared_by)):
+                layer_name = kv_cache_tensor.shared_by[idx]
+                if "linear_attn" in layer_name and layer_name not in kv_cache_sizes:
+                    # for mamba linear attention
+                    kv_cache_size = kv_cache_tensor.size
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # share the KV cache between the self_attn specs in the same group
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_sizes[layer_name_inner] = kv_cache_size
+                elif "attn" in layer_name and layer_name not in kv_cache_sizes:
+                    kv_tensor_split_factor = 2
+                    kv_tensor_size = int(kv_cache_tensor.size // kv_tensor_split_factor)
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # share the KV cache between the self_attn specs in the same group
+                        if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
+                            kv_cache_sizes[layer_name_inner] = kv_tensor_size
-        kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+        layer_names = set()
+        for group in kv_cache_config.kv_cache_groups:
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                layer_names.add(layer_name)
+        assert layer_names == set(kv_cache_sizes.keys()), "Some layers are not correctly initialized"
+        return kv_cache_sizes
```
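The split factor here and the later `* 2` in the allocator are two halves of the same bookkeeping: `kv_cache_tensor.size` covers K and V together, each `attn` layer records half of it, and the allocation step sums the halves back before converting bytes into blocks. A worked sketch with made-up sizes (not taken from the PR):

```python
# Assumed numbers: one KV cache tensor of 64 MiB serves both the K and
# the V half of a self-attention layer; one page (block) costs 64 KiB.
kv_cache_tensor_size = 64 * 1024 * 1024   # bytes, K + V together
page_size_bytes = 64 * 1024               # bytes per block (K + V)

# _calculate_kv_cache_tensors_size: each "attn" layer records half.
kv_tensor_split_factor = 2
kv_tensor_size = kv_cache_tensor_size // kv_tensor_split_factor

# _allocate_kv_cache_and_reshape_tensors: the halves are summed back
# before converting bytes into block counts.
sum_page_size_bytes = kv_tensor_size * 2
assert sum_page_size_bytes % page_size_bytes == 0
num_blocks = sum_page_size_bytes // page_size_bytes
print(num_blocks)  # 1024
```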
```diff
+    def _allocate_kv_cache_and_reshape_tensors(
+        self,
+        kv_cache_config: KVCacheConfig,
+        kv_cache_sizes: dict[str, int],
+    ) -> dict[str, torch.Tensor]:
+        """
+        Allocate the KV cache tensors to the desired shape and dtype.
+        Args:
+            kv_cache_config: The KV cache config
+            kv_cache_sizes: The KV cache size of each layer
+        Returns:
+            dict[str, torch.Tensor]: A map between layer names to their
+            corresponding memory buffer for KV cache.
+        """
+        kv_caches: dict[str, torch.Tensor] = {}
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
-            if not isinstance(kv_cache_spec, FullAttentionSpec):
-                raise ValueError("Unknown KV cache spec type.")
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
+                if isinstance(kv_cache_spec, AttentionSpec):
+                    kv_tensor_size = kv_cache_sizes[layer_name]
+                    assert kv_tensor_size is not None
+                    sum_page_size_bytes = kv_tensor_size * 2
+                    assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
+                    assert num_blocks >= kv_cache_config.num_blocks
-                if "attn" not in layer_name:
-                    continue
+                    if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
+                        block_size = attn_backend.get_supported_block_size()[0]
-                # Compute how many blocks this layer can hold.
-                tensor_size = kv_cache_sizes[layer_name]
-                assert tensor_size % kv_cache_spec.page_size_bytes == 0
-                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+                        block_size_chunk = kv_cache_spec.block_size // block_size
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks * block_size_chunk,
+                            block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size,
+                        )
+                    else:
+                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
+                        )
+                    dtype = kv_cache_spec.dtype
+                    k_shape = kv_cache_shape[1:]
+                    v_shape = k_shape
+                    k_cache = torch_npu.empty_with_format(
+                        size=k_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
+                    )
+                    v_cache = torch_npu.empty_with_format(
+                        size=v_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
+                    )
+                    kv_caches[layer_name] = (k_cache, v_cache)
```
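When the backend advertises `get_supported_block_size` and hybrid blocks are enabled, each scheduler-level block is split into `block_size_chunk` backend-sized blocks, so the physical block count grows while the total token capacity stays fixed. A small sketch with assumed sizes (not from the PR):

```python
# Assumed: the scheduler uses 128-token blocks, but the backend only
# supports 64-token blocks.
num_blocks = 1024          # blocks as counted by the KV cache manager
spec_block_size = 128      # kv_cache_spec.block_size
backend_block_size = 64    # attn_backend.get_supported_block_size()[0]

block_size_chunk = spec_block_size // backend_block_size   # 2
physical_blocks = num_blocks * block_size_chunk            # 2048

# Token capacity is unchanged; only the block granularity differs.
assert num_blocks * spec_block_size == physical_blocks * backend_block_size
```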
```diff
+                elif isinstance(kv_cache_spec, MambaSpec):
+                    tensor_size = kv_cache_sizes[layer_name]
+                    dtype = kv_cache_spec.dtype
+                    tensor_element_size = torch.tensor([], dtype=dtype).element_size()
+                    raw_tensor = torch.zeros(tensor_size // tensor_element_size, dtype=dtype, device=self.device)
+                    assert tensor_size is not None
+                    assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+                    # `num_blocks` must be >= the number KVCacheManager may allocate.
+                    assert num_blocks >= kv_cache_config.num_blocks
+                    state_tensors = []
+                    target_idx = 0
+                    start_idx = 0
+                    for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
+                        # Normally there are both a conv state and an ssm state in this
+                        # loop; some special models have only a conv state.
+                        target_shape = (num_blocks, *shape)
-                # Determine the KV cache shape from backend.
-                kv_cache_shape = self._get_kv_cache_shape_310p(
-                    attn_backend=attn_backend,
-                    kv_cache_spec=kv_cache_spec,
-                    num_blocks=num_blocks,
-                )
-                shape = kv_cache_shape[1:]
-                dtype = kv_cache_spec.dtype
-                k_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
-                v_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
-                kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
-        return kv_cache_raw_tensors
-
-    def _reshape_kv_cache_tensors_310p(
-        self,
-        kv_cache_config: KVCacheConfig,
-        kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]],
-    ) -> dict[str, Any]:
-        """
-        Transform allocated KV cache buffers into the final layout required by 310P.
-        For 310P, this mainly means casting tensors into the expected ACL format.
-        """
-        kv_caches: dict[str, Any] = {}
-        for group in self._kv_cache_spec_attn_group_iterator():
-            kv_cache_spec = group.kv_cache_spec
-            if not isinstance(kv_cache_spec, FullAttentionSpec):
-                raise ValueError("Unknown KV cache spec type.")
-            for layer_name in group.layer_names:
-                if layer_name in self.runner_only_attn_layers:
-                    continue
-                if "attn" not in layer_name:
-                    continue
-                k_tensor, v_tensor = kv_cache_raw_tensors[layer_name]
-                # In-place ACL layout cast to avoid the extra allocation of npu_format_cast,
-                # which can spike peak memory (~2x KV cache) during initialization and trigger OOM.
-                torch_npu.npu_format_cast_(k_tensor, self._acl_format)
-                torch_npu.npu_format_cast_(v_tensor, self._acl_format)
-                kv_caches[layer_name] = (k_tensor, v_tensor)
+                        target_idx += torch.prod(torch.tensor(target_shape)).item()
+                        tensor = raw_tensor[start_idx:target_idx].view(target_shape)
+                        start_idx = target_idx
+                        state_tensors.append(tensor)
+                    kv_caches[layer_name] = state_tensors
+                else:
+                    raise ValueError("Unknown KV cache spec type.")
         return kv_caches
```
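The `MambaSpec` branch carves one flat buffer into per-state views (typically a conv state and an ssm state) without copying. A runnable sketch with assumed state shapes (not from the PR); the views keep the buffer's dtype, as in the code above:

```python
import torch

num_blocks = 4
shapes = [(8, 16), (32,)]             # kv_cache_spec.shapes (assumed)
dtypes = [torch.float16, torch.float16]

total = sum(num_blocks * torch.Size(s).numel() for s in shapes)
raw_tensor = torch.zeros(total, dtype=torch.float16)

state_tensors, start_idx, target_idx = [], 0, 0
for shape, dtype in zip(shapes, dtypes):
    target_shape = (num_blocks, *shape)
    target_idx += torch.prod(torch.tensor(target_shape)).item()
    # A view into the flat buffer: no copy, just reinterpretation.
    state_tensors.append(raw_tensor[start_idx:target_idx].view(target_shape))
    start_idx = target_idx

assert state_tensors[0].shape == (4, 8, 16)
assert state_tensors[1].shape == (4, 32)
```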
```diff
-    def _get_kv_cache_shape_310p(
-        self,
-        attn_backend: Any,
-        kv_cache_spec: FullAttentionSpec,
-        num_blocks: int,
-    ) -> tuple[int, ...]:
-        """
-        Compute KV cache shape with (optional) hybrid block support.
-        """
-        if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
-            block_size = attn_backend.get_supported_block_size()[0]
-            block_size_chunk = kv_cache_spec.block_size // block_size
-            return attn_backend.get_kv_cache_shape(
-                num_blocks * block_size_chunk,
-                block_size,
-                kv_cache_spec.num_kv_heads,
-                kv_cache_spec.head_size,
-            )
-        return attn_backend.get_kv_cache_shape(
-            num_blocks,
-            kv_cache_spec.block_size,
-            kv_cache_spec.num_kv_heads,
-            kv_cache_spec.head_size,
-        )
```
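The removed helper and reshape step also highlight the allocation-strategy change: the old path allocated in the default layout and cast in place with `torch_npu.npu_format_cast_` (the out-of-place `npu_format_cast` would briefly double peak memory), while the new path allocates directly in the target layout via `torch_npu.empty_with_format`. A sketch of both approaches, mirroring the calls used in the diff (runs only on Ascend hardware with `torch_npu`; the shape is assumed):

```python
import torch
import torch_npu
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

shape, dtype = (1024, 16, 64, 16), torch.float16  # assumed KV shape

# Old path: allocate in the default format, then cast the layout in
# place, avoiding the second buffer an out-of-place cast would create.
k_tensor = torch.zeros(shape, dtype=dtype, device="npu")
torch_npu.npu_format_cast_(k_tensor, ACL_FORMAT_FRACTAL_NZ)

# New path: skip the cast entirely by allocating directly in the
# target ACL format, as the refactored allocator does.
k_cache = torch_npu.empty_with_format(
    size=shape, dtype=dtype, device="npu", acl_format=ACL_FORMAT_FRACTAL_NZ)
```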