From 2d885869c558734dba9ffb42ea8e7c090bb04c22 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Wed, 24 Sep 2025 11:32:34 +0800 Subject: [PATCH] [KVCache][Bugfix] Fix kv cache initialization error of attention layer (#3113) ### What this PR does / why we need it? Fixes #3096 1. Fix kv cache initialization error of attention layer. There are some models with layer name like `attn.attn`, instead of `self_attn`, but the initialization of kv cache tensors only checks for `self_attn` and `linear_attn`, which leads to the error `AssertionError: Some layers are not correctly initialized` 2. Set the default value of input arg `sampling_metadata` in `compute_logits` for the modeling files in vllm-ascend. Thus fixing the error `Qwen3NextForCausalLM.compute_logits() missing 1 required positional argument: 'sampling_metadata'` ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? test locally with internlm - vLLM version: v0.10.2 - vLLM main: https://github.com/vllm-project/vllm/commit/5aeb9254521023f97aca292b3478aa7ff485ffb2 --------- Signed-off-by: MengqingCao --- vllm_ascend/models/deepseek_mtp.py | 2 +- vllm_ascend/models/qwen3_next.py | 2 +- vllm_ascend/torchair/models/qwen2.py | 2 +- vllm_ascend/torchair/models/torchair_deepseek_mtp.py | 2 +- vllm_ascend/torchair/models/torchair_pangu_moe.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 8 +++++--- 6 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py index b3daa6c..0c4f173 100644 --- a/vllm_ascend/models/deepseek_mtp.py +++ b/vllm_ascend/models/deepseek_mtp.py @@ -166,7 +166,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata, # type: ignore + sampling_metadata=None, # type: ignore spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) diff --git 
a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index 7d1481e..175a529 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -986,7 +986,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata, # type: ignore + sampling_metadata=None, # type: ignore ) -> Optional[torch.Tensor]: return self.logits_processor(self.lm_head, hidden_states, sampling_metadata) diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py index 56620dc..6e4990d 100644 --- a/vllm_ascend/torchair/models/qwen2.py +++ b/vllm_ascend/torchair/models/qwen2.py @@ -344,7 +344,7 @@ class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata, # type: ignore + sampling_metadata=None, # type: ignore ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py index 6ef90f4..a7c5a6e 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py @@ -170,7 +170,7 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): def compute_logits( self, hidden_states: torch.Tensor, - sampling_metadata, # type: ignore + sampling_metadata=None, # type: ignore spec_step_idx: int = 0, ) -> torch.Tensor: current_step_idx = (spec_step_idx % self.num_mtp_layers) diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py index e38dc78..195ffde 100644 --- a/vllm_ascend/torchair/models/torchair_pangu_moe.py +++ b/vllm_ascend/torchair/models/torchair_pangu_moe.py @@ -936,7 +936,7 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP): def compute_logits( self, 
hidden_states: torch.Tensor, - sampling_metadata, # type: ignore + sampling_metadata=None, # type: ignore ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 98aeac6..f4656dd 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2784,9 +2784,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): for idx in range(len(kv_cache_tensor.shared_by)): layer_name = kv_cache_tensor.shared_by[idx] if "linear_attn" in layer_name: + # for mamba linear attention for layer_name_inner in kv_cache_tensor.shared_by: - if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys( - ): + if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \ + layer_name_inner in kv_cache_raw_tensors.keys(): continue if self.vllm_config.kv_transfer_config is None: tensor = torch.zeros(kv_cache_tensor.size, @@ -2800,7 +2801,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): tensor = self._align_memory( tensor, alignment)[:kv_cache_tensor.size] kv_cache_raw_tensors[layer_name_inner] = tensor - elif "self_attn" in layer_name: + elif "attn" in layer_name: + # for other attentions, e.g., self_attn, sliding window attn if self.vllm_config.kv_transfer_config is None: k_tensor = torch.zeros(kv_cache_tensor.size // 2, dtype=torch.int8,