diff --git a/vllm_ascend/_310p/model_runner_310p.py b/vllm_ascend/_310p/model_runner_310p.py
index 7b115fb7..00da9b0f 100644
--- a/vllm_ascend/_310p/model_runner_310p.py
+++ b/vllm_ascend/_310p/model_runner_310p.py
@@ -17,12 +17,10 @@
 from __future__ import annotations
 
-from typing import Any
-
 import torch
 import torch_npu
 
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
-from vllm.v1.worker.utils import bind_kv_cache
+from vllm.logger import logger
+from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, MambaSpec
 
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -33,154 +31,157 @@ class NPUModelRunner310(NPUModelRunner):
         super().__init__(*args, **kwargs)
         self._acl_format = ACL_FORMAT_FRACTAL_NZ
 
-    def initialize_kv_cache_tensors(
-        self,
-        kv_cache_config: KVCacheConfig,
-    ) -> dict[str, Any]:
+    def initialize_kv_cache_tensors(self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
         """
-        Initialize KV cache tensors for 310P.
+        Initialize the memory buffer for KV cache.
 
-        1) allocate buffers
-        2) reshape / transform to the final layout
-        3) optional cross-layer sharing
-        4) bind buffers to the static forward context
+        Args:
+            kv_cache_config: The KV cache config
+        Returns:
+            dict[str, torch.Tensor]: A map from layer names to their
+            corresponding KV cache buffers.
         """
         # 310P limitation: KV transfer is not supported.
         if self.vllm_config.kv_transfer_config is not None:
             raise ValueError("KV cache transfer is not supported for 310P.")
-
-        kv_cache_raw_tensors = self._allocate_kv_cache_tensors_310p(kv_cache_config)
-        kv_caches = self._reshape_kv_cache_tensors_310p(kv_cache_config, kv_cache_raw_tensors)
-
-        # Keep the same cross-layer KV cache sharing logic as the main branch.
-        # For 310P, this is expected to be empty in most cases, but keeping it
-        # makes the code path consistent and easier to reason about.
+        if self.use_sparse:
+            raise ValueError("Deepseek Sparse Attention is not supported for 310P.")
+        if self.model_config.use_mla:
+            raise ValueError("MLAAttention is not supported for 310P.")
+        # Compute the per-layer KV cache buffer sizes.
+        kv_cache_size = self._calculate_kv_cache_tensors_size(kv_cache_config)
+        # Allocate the KV cache tensors and reshape them into their final layout.
+        kv_caches = self._allocate_kv_cache_and_reshape_tensors(kv_cache_config, kv_cache_size)
+        # Set up cross-layer KV cache sharing.
         for layer_name, target_layer_name in self.shared_kv_cache_layers.items():
+            logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name)
             kv_caches[layer_name] = kv_caches[target_layer_name]
-        # 310P devices do not support the "longcat_flash" special case here, so always be "1".
-        bind_kv_cache(
-            kv_caches,
-            self.compilation_config.static_forward_context,
-            self.kv_caches,
-            1,
-        )
+        from vllm.v1.worker.utils import bind_kv_cache
+
+        bind_kv_cache(kv_caches, self.compilation_config.static_forward_context, self.kv_caches)
         return kv_caches
 
-    def _allocate_kv_cache_tensors_310p(
-        self,
-        kv_cache_config: KVCacheConfig,
-    ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
+    def _calculate_kv_cache_tensors_size(self, kv_cache_config: KVCacheConfig) -> dict[str, int]:
         """
-        Allocate KV cache buffers for each attention layer.
+        Calculate the KV cache buffer size of each layer. The buffers still need to be
+        reshaped to their final layout before being used by the model.
 
-        Unlike the non-310p path, 310P uses torch.zeros directly with the final dtype,
-        and defers layout casting (ACL format) to the reshape step.
+        Args:
+            kv_cache_config: The KV cache config
+        Returns:
+            dict[str, int]: A map from layer names to their
+            corresponding buffer sizes in bytes.
         """
-        # Build a mapping: layer_name -> tensor_size(bytes).
+        # Build a mapping: layer_name -> buffer_size (bytes).
         kv_cache_sizes: dict[str, int] = {}
         for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            # 310P limitation: a KV cache tensor must not be shared by multiple layers.
-            assert len(kv_cache_tensor.shared_by) == 1, (
-                "KV cache tensor shared by multiple layers is not supported in 310P."
-            )
-            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
+            # TODO: REFACTOR ME to share the hybrid cache
+            for idx in range(len(kv_cache_tensor.shared_by)):
+                layer_name = kv_cache_tensor.shared_by[idx]
+                if "linear_attn" in layer_name and layer_name not in kv_cache_sizes:
+                    # Mamba / linear attention layers use the full tensor size.
+                    kv_cache_size = kv_cache_tensor.size
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # Share the KV cache between the linear_attn layers in the same group.
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_sizes[layer_name_inner] = kv_cache_size
+                elif "attn" in layer_name and layer_name not in kv_cache_sizes:
+                    # K and V are allocated as two separate tensors, so each gets half the size.
+                    kv_tensor_split_factor = 2
+                    kv_tensor_size = int(kv_cache_tensor.size // kv_tensor_split_factor)
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # Share the KV cache between the self_attn layers in the same group.
+                        if "attn" in layer_name_inner and "linear_attn" not in layer_name_inner:
+                            kv_cache_sizes[layer_name_inner] = kv_tensor_size
 
-        kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
+        # Sanity check: every layer in the KV cache groups got a size assigned.
+        layer_names = set()
+        for group in kv_cache_config.kv_cache_groups:
+            for layer_name in group.layer_names:
+                if layer_name in self.runner_only_attn_layers:
+                    continue
+                layer_names.add(layer_name)
+        assert layer_names == set(kv_cache_sizes.keys()), "Some layers are not correctly initialized"
+        return kv_cache_sizes
+
+    def _allocate_kv_cache_and_reshape_tensors(
+        self,
+        kv_cache_config: KVCacheConfig,
+        kv_cache_sizes: dict[str, int],
+    ) -> dict[str, torch.Tensor]:
+        """
+        Allocate the KV cache tensors and reshape them to the desired shape and dtype.
+
+        Args:
+            kv_cache_config: The KV cache config
+            kv_cache_sizes: The KV cache buffer size of each layer, in bytes
+        Returns:
+            dict[str, torch.Tensor]: A map from layer names to their
+            corresponding KV cache buffers.
+        """
+        kv_caches: dict[str, torch.Tensor] = {}
         for group in self._kv_cache_spec_attn_group_iterator():
             kv_cache_spec = group.kv_cache_spec
             attn_backend = group.backend
-
-            if not isinstance(kv_cache_spec, FullAttentionSpec):
-                raise ValueError("Unknown KV cache spec type.")
-
             for layer_name in group.layer_names:
                 if layer_name in self.runner_only_attn_layers:
                     continue
+                if isinstance(kv_cache_spec, AttentionSpec):
+                    kv_tensor_size = kv_cache_sizes[layer_name]
+                    assert kv_tensor_size is not None
+                    # kv_tensor_size covers only K (or V), so the combined size is twice that.
+                    sum_page_size_bytes = kv_tensor_size * 2
+                    assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes
+                    assert num_blocks >= kv_cache_config.num_blocks
 
-                if "attn" not in layer_name:
-                    continue
+                    if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
+                        block_size = attn_backend.get_supported_block_size()[0]
 
-                # Compute how many blocks this layer can hold.
-                tensor_size = kv_cache_sizes[layer_name]
-                assert tensor_size % kv_cache_spec.page_size_bytes == 0
-                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+                        block_size_chunk = kv_cache_spec.block_size // block_size
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks * block_size_chunk,
+                            block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size,
+                        )
+                    else:
+                        kv_cache_shape = attn_backend.get_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size
+                        )
+                    dtype = kv_cache_spec.dtype
+                    k_shape = kv_cache_shape[1:]
+                    v_shape = k_shape
+                    k_cache = torch_npu.empty_with_format(
+                        size=k_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
+                    )
+                    v_cache = torch_npu.empty_with_format(
+                        size=v_shape, dtype=dtype, device=self.device, acl_format=self._acl_format
+                    )
+                    kv_caches[layer_name] = (k_cache, v_cache)
+                elif isinstance(kv_cache_spec, MambaSpec):
+                    tensor_size = kv_cache_sizes[layer_name]
+                    assert tensor_size is not None
+                    assert tensor_size % kv_cache_spec.page_size_bytes == 0
+                    num_blocks = tensor_size // kv_cache_spec.page_size_bytes
+                    assert num_blocks >= kv_cache_config.num_blocks
+                    dtype = kv_cache_spec.dtype
+                    tensor_element_size = torch.tensor([], dtype=dtype).element_size()
+                    raw_tensor = torch.zeros(tensor_size // tensor_element_size, dtype=dtype, device=self.device)
 
-                # `num_blocks` must be >= the number KVCacheManager may allocate.
-                assert num_blocks >= kv_cache_config.num_blocks
+                    state_tensors = []
+                    target_idx = 0
+                    start_idx = 0
+                    for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
+                        # Normally this loop yields a conv state and an SSM state;
+                        # some special models only have a conv state.
+                        target_shape = (num_blocks, *shape)
 
-                # Determine the KV cache shape from backend.
-                kv_cache_shape = self._get_kv_cache_shape_310p(
-                    attn_backend=attn_backend,
-                    kv_cache_spec=kv_cache_spec,
-                    num_blocks=num_blocks,
-                )
-
-                shape = kv_cache_shape[1:]
-                dtype = kv_cache_spec.dtype
-
-                k_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
-                v_tensor = torch.zeros(shape, dtype=dtype, device=self.device)
-                kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
-
-        return kv_cache_raw_tensors
-
-    def _reshape_kv_cache_tensors_310p(
-        self,
-        kv_cache_config: KVCacheConfig,
-        kv_cache_raw_tensors: dict[str, tuple[torch.Tensor, torch.Tensor]],
-    ) -> dict[str, Any]:
-        """
-        Transform allocated KV cache buffers into the final layout required by 310P.
-
-        For 310P, this mainly means casting tensors into the expected ACL format.
-        """
-        kv_caches: dict[str, Any] = {}
-
-        for group in self._kv_cache_spec_attn_group_iterator():
-            kv_cache_spec = group.kv_cache_spec
-            if not isinstance(kv_cache_spec, FullAttentionSpec):
-                raise ValueError("Unknown KV cache spec type.")
-
-            for layer_name in group.layer_names:
-                if layer_name in self.runner_only_attn_layers:
-                    continue
-                if "attn" not in layer_name:
-                    continue
-
-                k_tensor, v_tensor = kv_cache_raw_tensors[layer_name]
-
-                # In-place ACL layout cast to avoid the extra allocation of npu_format_cast,
-                # which can spike peak memory (~2x KV cache) during initialization and trigger OOM.
-                torch_npu.npu_format_cast_(k_tensor, self._acl_format)
-                torch_npu.npu_format_cast_(v_tensor, self._acl_format)
-                kv_caches[layer_name] = (k_tensor, v_tensor)
+                        # Advance the flat-buffer cursor and carve out a view for this state.
+                        target_idx += torch.prod(torch.tensor(target_shape)).item()
+                        tensor = raw_tensor[start_idx:target_idx].view(target_shape)
+                        start_idx = target_idx
+                        state_tensors.append(tensor)
+                    kv_caches[layer_name] = state_tensors
+                else:
+                    raise ValueError("Unknown KV cache spec type.")
 
         return kv_caches
-
-    def _get_kv_cache_shape_310p(
-        self,
-        attn_backend: Any,
-        kv_cache_spec: FullAttentionSpec,
-        num_blocks: int,
-    ) -> tuple[int, ...]:
-        """
-        Compute KV cache shape with (optional) hybrid block support.
-        """
-        if hasattr(attn_backend, "get_supported_block_size") and self.use_hybrid_blocks:
-            block_size = attn_backend.get_supported_block_size()[0]
-            block_size_chunk = kv_cache_spec.block_size // block_size
-            return attn_backend.get_kv_cache_shape(
-                num_blocks * block_size_chunk,
-                block_size,
-                kv_cache_spec.num_kv_heads,
-                kv_cache_spec.head_size,
-            )
-
-        return attn_backend.get_kv_cache_shape(
-            num_blocks,
-            kv_cache_spec.block_size,
-            kv_cache_spec.num_kv_heads,
-            kv_cache_spec.head_size,
-        )
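
The size split in _calculate_kv_cache_tensors_size is easiest to sanity-check standalone. Below is a minimal sketch of just the splitting rule, with hypothetical stand-ins (FakeTensor for vLLM's KVCacheTensor, split_sizes for the method) and the per-group bookkeeping dropped: self-attention layers get half of the shared tensor because K and V are later allocated as two separate tensors, while linear-attention (Mamba) layers keep the full size.

    # Hypothetical, simplified rework of the splitting rule, for illustration only.
    from dataclasses import dataclass

    @dataclass
    class FakeTensor:          # stand-in for vllm.v1.kv_cache_interface.KVCacheTensor
        size: int              # bytes; K and V combined for self-attention layers
        shared_by: list[str]   # layer names sharing this buffer

    def split_sizes(tensors: list[FakeTensor]) -> dict[str, int]:
        sizes: dict[str, int] = {}
        for t in tensors:
            for name in t.shared_by:
                if "linear_attn" in name:
                    sizes[name] = t.size       # Mamba state keeps the full buffer
                elif "attn" in name:
                    sizes[name] = t.size // 2  # half each for the separate K and V tensors
        return sizes

    print(split_sizes([FakeTensor(1024, ["layers.0.self_attn.attn"]),
                       FakeTensor(1024, ["layers.1.linear_attn"])]))
    # -> {'layers.0.self_attn.attn': 512, 'layers.1.linear_attn': 1024}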
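
The hybrid-block branch in _allocate_kv_cache_and_reshape_tensors does not change how much memory a layer gets; it only re-expresses the same storage at the backend's supported block granularity. A sketch of the invariant, with illustrative numbers (256 and 128 are assumptions, not values from this patch):

    # One spec block of 256 tokens becomes two backend blocks of 128 tokens, so the
    # block count scales by block_size_chunk while the total element count matches.
    spec_block_size, backend_block_size, num_blocks = 256, 128, 64
    block_size_chunk = spec_block_size // backend_block_size
    shape_spec = (num_blocks, spec_block_size)
    shape_backend = (num_blocks * block_size_chunk, backend_block_size)
    assert shape_spec[0] * shape_spec[1] == shape_backend[0] * shape_backend[1]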
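
The cross-layer sharing loop in initialize_kv_cache_tensors is plain dict aliasing: the reusing layer receives references to the same K/V tensors, not copies. A self-contained illustration with ordinary CPU tensors and made-up layer names:

    import torch

    kv_caches = {"layers.0.attn": (torch.zeros(2), torch.zeros(2))}
    kv_caches["layers.2.attn"] = kv_caches["layers.0.attn"]  # reuse, as in the diff
    kv_caches["layers.2.attn"][0].fill_(1.0)
    assert kv_caches["layers.0.attn"][0][0] == 1.0           # same underlying storage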
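
In the MambaSpec branch, one flat buffer is carved into per-state views by advancing a cursor. Note that the per-state dtype unpacked from the zip is never used for the slicing, so the code effectively assumes one dtype across states; the sketch below makes that assumption explicit, with invented shapes:

    import torch

    num_blocks = 4
    shapes = [(8, 16), (32,)]  # e.g. conv state and SSM state per block (invented)
    flat_len = sum(num_blocks * int(torch.prod(torch.tensor(s))) for s in shapes)
    raw = torch.zeros(flat_len, dtype=torch.float16)

    views, start = [], 0
    for shape in shapes:
        target_shape = (num_blocks, *shape)
        end = start + int(torch.prod(torch.tensor(target_shape)))
        views.append(raw[start:end].view(target_shape))  # a view, not a copy
        start = end
    assert [tuple(v.shape) for v in views] == [(4, 8, 16), (4, 32)]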