[Refact.]: refactoring 310p-kv cache allocator, align with main branch (#6270)
### What this PR does / why we need it? refactoring 310p-kv cache allocator, align with main branch vLLM version: v0.14.0 vLLM main: https://github.com/vllm-project/vllm-ascend/pull/6270 Qwen2.5-7B E2E Test --------- Signed-off-by: pu-zhe <puzhe1@h-partners.com> Signed-off-by: pu-zhe <zpuaa@outlook.com> Co-authored-by: pu-zhe <puzhe1@h-partners.com>
This commit is contained in:
@@ -17,12 +17,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
from vllm.logger import logger
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
|
||||
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
@@ -33,154 +31,157 @@ class NPUModelRunner310(NPUModelRunner):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._acl_format = ACL_FORMAT_FRACTAL_NZ
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize KV cache tensors for 310P.
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
1) allocate buffers
|
||||
2) reshape / transform to the final layout
|
||||
3) optional cross-layer sharing
|
||||
4) bind buffers to the static forward context
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
# 310P limitation: KV transfer is not supported.
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
raise ValueError("KV cache transfer is not supported for 310P.")
|
||||
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors_310p(kv_cache_config)
|
||||
kv_caches = self._reshape_kv_cache_tensors_310p(kv_cache_config, kv_cache_raw_tensors)
|
||||
|
||||
# Keep the same cross-layer KV cache sharing logic as the main branch.
|
||||
# For 310P, this is expected to be empty in most cases, but keeping it
|
||||
# makes the code path consistent and easier to reason about.
|
||||
if self.use_sparse:
|
||||
raise ValueError("Deepseek Sparse Attention is not supported for 310P.")
|
||||
if self.model_config.use_mla:
|
||||
raise ValueError("MLAAttention is not supported for 310P.")
|
||||
# Initialize the memory size for KV cache
|
||||
kv_cache_size = self._calculate_kv_cache_tensors_size(kv_cache_config)
|
||||
# Allocate and reshape KV cache Tensors
|
||||
kv_caches = self._allocate_kv_cache_and_reshape_tensors(kv_cache_config, kv_cache_size)
|
||||
# Set up cross-layer KV cache sharing
|
||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
|
||||
logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
# 310P devices do not support the "longcat_flash" special case here, so always be "1".
|
||||
bind_kv_cache(
|
||||
kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches,
|
||||
1,
|
||||
)
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def _allocate_kv_cache_tensors_310p(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
|
||||
def _calculate_kv_cache_tensors_size(self, kv_cache_config: KVCacheConfig) -> dict[str, int]:
|
||||
"""
|
||||
Allocate KV cache buffers for each attention layer.
|
||||
Initializes the KV cache size. The buffer needs to be reshaped to the desired shape before being used by
|
||||
the models.
|
||||
|
||||
Unlike the non-310p path, 310P uses torch.zeros directly with the final dtype,
|
||||
and defers layout casting (ACL format) to the reshape step.
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, int]: A map between layer names to their
|
||||
corresponding memory buffer size.
|
||||
"""
|
||||
# Build a mapping: layer_name -> tensor_size(bytes).
|
||||
# init kv cache tensors
|
||||
kv_cache_sizes: dict[str, int] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
# 310P limitation: a KV cache tensor must not be shared by multiple layers.
|
||||
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||
"KV cache tensor shared by multiple layers is not supported in 310P."
|
||||
)
|
||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||
# TODO: REFACTOR ME to sharing hybrid cache
|
||||
for idx in range(len(kv_cache_tensor.shared_by)):
|
||||
layer_name = kv_cache_tensor.shared_by[idx]
|
||||
if "linear_attn" in layer_name and layer_name not in kv_cache_sizes:
|
||||
# for mamba linear attention
|
||||
kv_cache_size = kv_cache_tensor.size
|
||||
for layer_name_inner in kv_cache_tensor.shared_by:
|
||||
# shared the kvcache between the self_attn specs in the same group
|
||||
if "linear_attn" in layer_name_inner:
|
||||
kv_cache_sizes[layer_name_inner] = kv_cache_size
|
||||
elif "attn" in layer_name and layer_name not in kv_cache_sizes:
|
||||
kv_tensor_split_factor = 2
|
||||
kv_tensor_size = int(kv_cache_tensor.size // kv_tensor_split_factor)
|
||||
for layer_name_inner in kv_cache_tensor.shared_by:
|
||||
# shared the kvcache between the self_attn specs in the same group
|
||||
if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
|
||||
kv_cache_sizes[layer_name_inner] = kv_tensor_size
|
||||
|
||||
kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_sizes.keys()), "Some layers are not correctly initialized"
|
||||
|
||||
return kv_cache_sizes
|
||||
|
||||
def _allocate_kv_cache_and_reshape_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_sizes: dict[str, int],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Allocate the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_sizes: The KV cache size of each layer
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
|
||||
if not isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
kv_tensor_size = kv_cache_sizes[layer_name]
|
||||
assert kv_tensor_size is not None
|
||||
sum_page_size_bytes = kv_tensor_size * 2
|
||||
assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
if "attn" not in layer_name:
|
||||
continue
|
||||
if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
|
||||
# Compute how many blocks this layer can hold.
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
assert tensor_size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk,
|
||||
block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
else:
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
|
||||
)
|
||||
dtype = kv_cache_spec.dtype
|
||||
k_shape = kv_cache_shape[1:]
|
||||
v_shape = k_shape
|
||||
k_cache = torch_npu.empty_with_format(
|
||||
size=k_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
|
||||
)
|
||||
v_cache = torch_npu.empty_with_format(
|
||||
size=v_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
|
||||
)
|
||||
kv_caches[layer_name] = (k_cache, v_cache)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
dtype = kv_cache_spec.dtype
|
||||
tensor_element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
raw_tensor = torch.zeros(tensor_size // tensor_element_size, dtype=dtype, device=self.device)
|
||||
assert tensor_size is not None
|
||||
assert tensor_size % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
# `num_blocks` must be >= the number KVCacheManager may allocate.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
state_tensors = []
|
||||
target_idx = 0
|
||||
start_idx = 0
|
||||
for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
|
||||
# normally, there is conv state and ssm state in this loop. And there is only
|
||||
# a conv state in some special models.
|
||||
target_shape = (num_blocks, *shape)
|
||||
|
||||
# Determine the KV cache shape from backend.
|
||||
kv_cache_shape = self._get_kv_cache_shape_310p(
|
||||
attn_backend=attn_backend,
|
||||
kv_cache_spec=kv_cache_spec,
|
||||
num_blocks=num_blocks,
|
||||
)
|
||||
|
||||
shape = kv_cache_shape[1:]
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
k_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
|
||||
v_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
|
||||
kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
|
||||
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
def _reshape_kv_cache_tensors_310p(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Transform allocated KV cache buffers into the final layout required by 310P.
|
||||
|
||||
For 310P, this mainly means casting tensors into the expected ACL format.
|
||||
"""
|
||||
kv_caches: dict[str, Any] = {}
|
||||
|
||||
for group in self._kv_cache_spec_attn_group_iterator():
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
if not isinstance(kv_cache_spec, FullAttentionSpec):
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
if "attn" not in layer_name:
|
||||
continue
|
||||
|
||||
k_tensor, v_tensor = kv_cache_raw_tensors[layer_name]
|
||||
|
||||
# In-place ACL layout cast to avoid the extra allocation of npu_format_cast,
|
||||
# which can spike peak memory (~2x KV cache) during initialization and trigger OOM.
|
||||
torch_npu.npu_format_cast_(k_tensor, self._acl_format)
|
||||
torch_npu.npu_format_cast_(v_tensor, self._acl_format)
|
||||
kv_caches[layer_name] = (k_tensor, v_tensor)
|
||||
target_idx += torch.prod(torch.tensor(target_shape)).item()
|
||||
tensor = raw_tensor[start_idx:target_idx].view(target_shape)
|
||||
start_idx = target_idx
|
||||
state_tensors.append(tensor)
|
||||
kv_caches[layer_name] = state_tensors
|
||||
else:
|
||||
raise ValueError("Unknown KV cache spec type.")
|
||||
|
||||
return kv_caches
|
||||
|
||||
def _get_kv_cache_shape_310p(
|
||||
self,
|
||||
attn_backend: Any,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
num_blocks: int,
|
||||
) -> tuple[int, ...]:
|
||||
"""
|
||||
Compute KV cache shape with (optional) hybrid block support.
|
||||
"""
|
||||
if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
return attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk,
|
||||
block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
return attn_backend.get_kv_cache_shape(
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user