From 5fed166a99cd35f6c8712d1d5144dee7aa247298 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Tue, 4 Nov 2025 17:26:54 +0800
Subject: [PATCH] [ModelRunner][Refactor] Refactor kv cache tensor initialization logic (#3106)

### What this PR does / why we need it?
Refactor the kv cache tensor initialization logic:
1. Unify the kvcache tensor initialization logic of deepseek and normal models.
2. Split `initialize_kv_cache_tensors` into `_allocate_kv_cache_tensors` and `_reshape_kv_cache_tensors`, following the GPU model runner in vLLM.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with existing tests.
1. prefill disaggregation scenario
2. deepseek + aclgraph/eager mode
3. qwen3 next

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

---------

Signed-off-by: MengqingCao
---
 vllm_ascend/worker/model_runner_v1.py | 376 ++++++++++----------------
 1 file changed, 150 insertions(+), 226 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4e88b405..f6d3bb20 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -3138,15 +3138,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             ])
 
         self.may_reinitialize_input_batch(kv_cache_config)
-
-        if self.use_sparse:
-            kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
-                kv_cache_config)
-        elif self.model_config.is_deepseek_mla:
-            kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
-                kv_cache_config)
-        else:
-            kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
+        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
 
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
@@ -3158,197 +3150,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         offset = (aligned_addr - data_ptr) // tensor.element_size()
         return tensor[int(offset):]
 
-    def initialize_kv_cache_tensors_deepseek_sfa(
-            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
-        kv_cache_sizes = {}
-        for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
-            assert len(kv_cache_tensor.shared_by) == 1, (
-                "KV cache tensor shared by multiple layers is not supported in "
-                "NPU.")
-            kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
-
-        kv_caches: Dict[str, torch.Tensor] = {}
-        for group in self._kv_cache_spec_attn_group_iterator():
-            kv_cache_spec = group.kv_cache_spec
-            attn_backend = group.backend
-            for layer_name in group.layer_names:
-                if layer_name in self.runner_only_attn_layers:
-                    continue
-                tensor_size = kv_cache_sizes[layer_name]
-                num_blocks = tensor_size // kv_cache_spec.page_size_bytes
-                if self.vllm_config.additional_config.get(
-                        "kv_cache_dtype", None) == 'int8':
-                    kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                elif hasattr(
-                        attn_backend, "get_supported_block_size"
-                ) and not self.model_config.is_deepseek_mla and not self.use_sparse:
-                    block_size = attn_backend.get_supported_block_size()[0]
-                    block_size_chunk = kv_cache_spec.block_size // block_size
-                    kv_cache_shape = attn_backend.get_kv_cache_shape(
-                        num_blocks * block_size_chunk, block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                else:
-                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                dtype = kv_cache_spec.dtype
-
-                alignment = 2 * 
1024 * 1024 - num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape - rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - nope_dim = head_size - rope_dim - nope_cache_shape = (num_blocks, block_size, num_kv_heads, - nope_dim) - rope_cache_shape = (num_blocks, block_size, num_kv_heads, - rope_dim) - #### k cache - # TODO(zzzzwwjj): wait transformers add these params - k_cache_shape = (num_blocks, block_size, 1, 128) - if self.vllm_config.kv_transfer_config is None: - # For no disaggregate pd scenario, allocate kv cache in normal way - rope_cache = torch.zeros(rope_cache_shape, - dtype=dtype, - device=self.device) - nope_cache = torch.zeros(nope_cache_shape, - dtype=dtype, - device=self.device) - rope_cache = self._convert_torch_format(rope_cache) - nope_cache = self._convert_torch_format(nope_cache) - - #### k cache - k_cache = torch.zeros(k_cache_shape, - dtype=dtype, - device=self.device) - k_cache = self._convert_torch_format(k_cache) - else: - - # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory - # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but - # we found there are also some exceptions during test, so we manual align those memory here, this part - # of code may consume 2M * 2 * elem_size memory every layer. - nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim - nope_allocate_shape_alignment = nope_allocate_shape + alignment - rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim - rope_allocate_shape_alignment = rope_allocate_shape + alignment - - nope_cache = torch.zeros(nope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - rope_cache = torch.zeros(rope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - #### k cache - # TODO(zzzzwwjj): wait transformers add these params - k_allocate_shape = num_blocks * block_size * 1 * 128 - k_allocate_shape_alignment = k_allocate_shape + alignment - k_cache = torch.zeros(k_allocate_shape_alignment, - dtype=dtype, - device=self.device) - - nope_cache = self._align_memory( - nope_cache, - alignment)[:nope_allocate_shape].view(nope_cache_shape) - rope_cache = self._align_memory( - rope_cache, - alignment)[:rope_allocate_shape].view(rope_cache_shape) - k_cache = self._align_memory( - k_cache, - alignment)[:k_allocate_shape].view(k_cache_shape) - - kv_caches[layer_name] = (nope_cache, rope_cache, k_cache) - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - - return kv_caches - - def initialize_kv_cache_tensors_deepseek_mla( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: - kv_cache_sizes = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in " - "NPU.") - kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size - - kv_caches: Dict[str, torch.Tensor] = {} - for group in self._kv_cache_spec_attn_group_iterator(): - kv_cache_spec = group.kv_cache_spec - attn_backend = group.backend - for layer_name in group.layer_names: - if layer_name in self.runner_only_attn_layers: - continue - tensor_size = kv_cache_sizes[layer_name] - num_blocks = tensor_size // kv_cache_spec.page_size_bytes - if self.vllm_config.additional_config.get( - "kv_cache_dtype", None) == 'int8': - kv_cache_shape = attn_backend.get_bsh_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, 
kv_cache_spec.head_size) - elif hasattr(attn_backend, "get_supported_block_size" - ) and not self.model_config.is_deepseek_mla: - block_size = attn_backend.get_supported_block_size()[0] - block_size_chunk = kv_cache_spec.block_size // block_size - kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks * block_size_chunk, block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - else: - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) - dtype = kv_cache_spec.dtype - - alignment = 2 * 1024 * 1024 - num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape - rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - nope_dim = head_size - rope_dim - nope_cache_shape = (num_blocks, block_size, num_kv_heads, - nope_dim) - rope_cache_shape = (num_blocks, block_size, num_kv_heads, - rope_dim) - if self.vllm_config.kv_transfer_config is None: - # For no disaggregate pd scenario, allocate kv cache in normal way - rope_cache = torch.zeros(rope_cache_shape, - dtype=dtype, - device=self.device) - nope_cache = torch.zeros(nope_cache_shape, - dtype=dtype, - device=self.device) - rope_cache = self._convert_torch_format(rope_cache) - nope_cache = self._convert_torch_format(nope_cache) - else: - - # In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory - # address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but - # we found there are also some exceptions during test, so we manual align those memory here, this part - # of code may consume 2M * 2 * elem_size memory every layer. - nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim - nope_allocate_shape_alignment = nope_allocate_shape + alignment - rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim - rope_allocate_shape_alignment = rope_allocate_shape + alignment - - nope_cache = torch.zeros(nope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - rope_cache = torch.zeros(rope_allocate_shape_alignment, - dtype=dtype, - device=self.device) - nope_cache = self._align_memory( - nope_cache, - alignment)[:nope_allocate_shape].view(nope_cache_shape) - rope_cache = self._align_memory( - rope_cache, - alignment)[:rope_allocate_shape].view(rope_cache_shape) - kv_caches[layer_name] = (nope_cache, rope_cache) - - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - - return kv_caches - def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ @@ -3360,6 +3161,34 @@ class NPUModelRunner(LoRAModelRunnerMixin): Dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. """ + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, + kv_cache_raw_tensors) + + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) + return kv_caches + + def _allocate_kv_cache_tensors( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + """ + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. 
+
+        NOTE: To support prefill disaggregation, we need to split the kvcache tensor
+        into k_cache and v_cache, and the addresses of both are aligned to 2M.
+
+        Args:
+            kv_cache_config: The KV cache config
+        Returns:
+            dict[str, torch.Tensor]: A map between layer names to their
+            corresponding memory buffer for KV cache.
+            dict[str, tuple[torch.Tensor, torch.Tensor]]: A map between layer names
+            to their corresponding memory buffers for K cache and V cache.
+        """
         # init kv cache tensors
         kv_cache_raw_tensors: dict[str, Union[torch.Tensor,
                                               Optional[torch.Tensor]]] = {}
@@ -3383,39 +3212,88 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                      device=self.device)
                 tensor = self._align_memory(
                     tensor, alignment)[:kv_cache_tensor.size]
+
             for layer_name_inner in kv_cache_tensor.shared_by:
-                # shared the kvcache between the linear_attn specs in the same group
+                # share the kvcache among the linear_attn specs in the same group
                 if "linear_attn" in layer_name_inner:
                     kv_cache_raw_tensors[layer_name_inner] = tensor
                 elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
                 ):
+                    # NOTE: We need to init the k cache tensor (nope cache tensor in mla) and
+                    # the v cache tensor (rope cache tensor in mla) separately to support llmdatadist,
+                    # as it only supports kv caches whose 0th dim is `num_blocks`.
+                    # For deepseek mla, we need to split the cache tensor according to the nope
+                    # head dim and rope head dim.
+                    if self.model_config.is_deepseek_mla:
+                        head_size = self.model_config.hf_text_config.qk_rope_head_dim + \
+                            self.model_config.hf_text_config.kv_lora_rank
+
+                    dsa_k_cache_factor = None
+                    dsa_k_cache_size = None
+                    if not self.model_config.is_deepseek_mla:
+                        # for non-MLA models, use FullAttentionSpec
+                        k_tensor_split_factor = 2
+                        v_tensor_split_factor = 2
+                    elif self.use_sparse:
+                        # for deepseek v3.2, DSA uses FullAttentionSpec
+                        # FullAttentionSpec allocates 2 * mla page size bytes,
+                        # and we use half of that for k cache in DSA
+                        dsa_k_cache_factor = 2
+                        k_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.kv_lora_rank
+                        v_tensor_split_factor = 2 * head_size / self.model_config.hf_text_config.qk_rope_head_dim
+                        dsa_k_cache_size = int(kv_cache_tensor.size //
+                                               dsa_k_cache_factor)
+                    else:
+                        # for other deepseek models, use MLAAttentionSpec
+                        k_tensor_split_factor = head_size / self.model_config.hf_text_config.kv_lora_rank
+                        v_tensor_split_factor = head_size / self.model_config.hf_text_config.qk_rope_head_dim
+
+                    k_tensor_size = int(kv_cache_tensor.size //
+                                        k_tensor_split_factor)
+                    v_tensor_size = int(kv_cache_tensor.size //
+                                        v_tensor_split_factor)
                     # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
-                        k_tensor = torch.zeros(kv_cache_tensor.size // 2,
+                        k_tensor = torch.zeros(k_tensor_size,
                                                dtype=torch.int8,
                                                device=self.device)
-                        v_tensor = torch.zeros(kv_cache_tensor.size // 2,
+                        v_tensor = torch.zeros(v_tensor_size,
                                                dtype=torch.int8,
                                                device=self.device)
+                        #### k cache: for deepseek sparse attention
+                        if dsa_k_cache_factor is not None:
+                            dsa_k_cache_tensor = torch.zeros(
+                                dsa_k_cache_size,
+                                dtype=torch.int8,
+                                device=self.device)
                     else:
-                        cache_size = kv_cache_tensor.size // 2
-                        cache_size_aligned = kv_cache_tensor.size // 2 + alignment
-                        k_tensor = torch.zeros(cache_size_aligned,
+                        k_tensor = torch.zeros(k_tensor_size + alignment,
                                                dtype=torch.int8,
                                                device=self.device)
-                        v_tensor = torch.zeros(cache_size_aligned,
+                        v_tensor = torch.zeros(v_tensor_size + alignment,
                                                dtype=torch.int8,
                                                device=self.device)
-                        k_tensor = self._align_memory(k_tensor,
- alignment)[:cache_size] - v_tensor = self._align_memory(v_tensor, - alignment)[:cache_size] + k_tensor = self._align_memory( + k_tensor, alignment)[:k_tensor_size] + v_tensor = self._align_memory( + v_tensor, alignment)[:v_tensor_size] + #### k cache: for deepseek sparse attention + if dsa_k_cache_factor is not None and dsa_k_cache_size is not None: + dsa_k_cache_tensor = torch.zeros( + dsa_k_cache_size + alignment, + dtype=torch.int8, + device=self.device) + dsa_k_cache_tensor = self._align_memory( + dsa_k_cache_tensor, + alignment)[:dsa_k_cache_size] + for layer_name_inner in kv_cache_tensor.shared_by: # shared the kvcache between the self_attn specs in the same group if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner): - kv_cache_raw_tensors[layer_name_inner] = (k_tensor, - v_tensor) + kv_cache_raw_tensors[layer_name_inner] = (k_tensor, v_tensor) if \ + not self.use_sparse else (k_tensor, v_tensor, dsa_k_cache_tensor) layer_names = set() for group in kv_cache_config.kv_cache_groups: @@ -3426,6 +3304,24 @@ class NPUModelRunner(LoRAModelRunnerMixin): assert layer_names == set(kv_cache_raw_tensors.keys( )), "Some layers are not correctly initialized" + return kv_cache_raw_tensors + + def _reshape_kv_cache_tensors( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape and dtype. + + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. + """ kv_caches: Dict[str, torch.Tensor] = {} for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec @@ -3437,14 +3333,22 @@ class NPUModelRunner(LoRAModelRunnerMixin): # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may # encounter OOM issue if isinstance(kv_cache_spec, FullAttentionSpec): - raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore - layer_name] + raw_dsa_k_tensor = None + if self.use_sparse: + raw_k_tensor, raw_v_tensor, raw_dsa_k_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name] + assert raw_dsa_k_tensor is not None + sum_page_size_bytes = raw_k_tensor.numel( + ) + raw_v_tensor.numel() + raw_dsa_k_tensor.numel() + else: + raw_k_tensor, raw_v_tensor = kv_cache_raw_tensors[ # type: ignore + layer_name] + sum_page_size_bytes = raw_k_tensor.numel( + ) + raw_v_tensor.numel() assert raw_k_tensor is not None assert raw_v_tensor is not None - assert (raw_k_tensor.numel() + raw_v_tensor.numel() - ) % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_k_tensor.numel() + raw_v_tensor.numel() - ) // kv_cache_spec.page_size_bytes + assert sum_page_size_bytes % kv_cache_spec.page_size_bytes == 0 + num_blocks = sum_page_size_bytes // kv_cache_spec.page_size_bytes # `num_blocks` is the number of blocks the model runner can use. 
# `kv_cache_config.num_blocks` is the number of blocks that @@ -3476,11 +3380,35 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - k_cache = raw_k_tensor.view(dtype).view(kv_cache_shape[1:]) + if not self.model_config.is_deepseek_mla: + k_shape = kv_cache_shape[1:] + v_shape = k_shape + else: + # k_cache: nope_cache v_cache: rope_cache + mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape + k_shape = [ + mla_num_blocks, mla_block_size, num_kv_heads, + self.model_config.hf_text_config.kv_lora_rank + ] + v_shape = [ + mla_num_blocks, mla_block_size, num_kv_heads, + self.model_config.hf_text_config.qk_rope_head_dim + ] + k_cache = raw_k_tensor.view(dtype).view(k_shape) k_cache = self._convert_torch_format(k_cache) - v_cache = raw_v_tensor.view(dtype).view(kv_cache_shape[1:]) + v_cache = raw_v_tensor.view(dtype).view(v_shape) v_cache = self._convert_torch_format(v_cache) - kv_caches[layer_name] = (k_cache, v_cache) + if self.use_sparse and raw_dsa_k_tensor is not None: + dsa_k_cache_shape = (num_blocks, + kv_cache_spec.block_size, 1, 128) + dsa_k_cache_size = ( + num_blocks + ) * kv_cache_spec.block_size * 128 * dtype.itemsize + dsa_k_cache = raw_dsa_k_tensor[:dsa_k_cache_size].view( + dtype).view(dsa_k_cache_shape) + kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache) + else: + kv_caches[layer_name] = (k_cache, v_cache) elif isinstance(kv_cache_spec, MambaSpec): raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor is not None @@ -3521,10 +3449,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: raise ValueError("Unknown KV cache spec type.") - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches) - return kv_caches def may_reinitialize_input_batch(self,
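For reviewers, here is a minimal, self-contained sketch (not part of the patch) of the two-phase flow that `initialize_kv_cache_tensors` now follows: first allocate flat per-layer byte buffers, then reinterpret them as typed `(k_cache, v_cache)` views. The function names `allocate_kv_cache_tensors` / `reshape_kv_cache_tensors`, the layer list, and the block geometry below are simplified stand-ins for illustration; the real runner methods take a `KVCacheConfig`, allocate on the NPU device, and also handle MLA, sparse-attention, and Mamba layouts.

```python
# Hedged sketch of the allocate-then-reshape split; all names and sizes here
# are illustrative assumptions, not the runner's real API.
import torch

LAYERS = ["model.layers.0.self_attn.attn", "model.layers.1.self_attn.attn"]
NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE = 8, 128, 4, 64
DTYPE = torch.float16
DTYPE_BYTES = torch.tensor([], dtype=DTYPE).element_size()  # 2 for float16
# bytes occupied by one block of K (or V) cache
BLOCK_BYTES = BLOCK_SIZE * NUM_KV_HEADS * HEAD_SIZE * DTYPE_BYTES


def allocate_kv_cache_tensors(layers):
    """Phase 1: one flat int8 buffer per K and per V, sized but not yet shaped."""
    return {
        name: (torch.zeros(NUM_BLOCKS * BLOCK_BYTES, dtype=torch.int8),
               torch.zeros(NUM_BLOCKS * BLOCK_BYTES, dtype=torch.int8))
        for name in layers
    }


def reshape_kv_cache_tensors(raw):
    """Phase 2: reinterpret each flat buffer as (num_blocks, block, heads, head_size)."""
    shape = (NUM_BLOCKS, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE)
    return {
        name: (k_raw.view(DTYPE).view(shape), v_raw.view(DTYPE).view(shape))
        for name, (k_raw, v_raw) in raw.items()
    }


kv_caches = reshape_kv_cache_tensors(allocate_kv_cache_tensors(LAYERS))
for name, (k_cache, v_cache) in kv_caches.items():
    print(name, tuple(k_cache.shape), k_cache.dtype)
```

Keeping allocation and reshaping separate mirrors the upstream GPU model runner: the raw buffers can be registered (and 2M-aligned for llmdatadist) before any backend-specific shape is imposed on them.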