From 1c0ecf806ad871b05875c835fe3ae22d6e60ead7 Mon Sep 17 00:00:00 2001
From: aipaes <82140963+aipaes@users.noreply.github.com>
Date: Fri, 6 Mar 2026 19:35:17 +0800
Subject: [PATCH] [bugfix] fix pass bug: pass the real rope dim for
 npu_rotary_embedding (#6880)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
Pass the real rope dim to npu_rotary_embedding instead of passing head_dim twice.

**before:**

q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, q_flat, k_flat, cos_sin_cache, self.head_dim, **self.head_dim,** True
)

**after:**

q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, q_flat, k_flat, cos_sin_cache, self.head_dim, **self.rope_dim,** True
)
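To make the effect of the new helper concrete, here is a minimal, self-contained sketch of the same selection logic. It is not part of this patch: SimpleNamespace stands in for vLLM's real ModelConfig, the names resolve_rope_dim and head_size are made up for the sketch, and the numeric values are illustrative only.

# Minimal sketch (assumptions: SimpleNamespace replaces vLLM's ModelConfig;
# values are illustrative, not taken from a real checkpoint).
from types import SimpleNamespace


def resolve_rope_dim(model_config):
    # Mirrors the get_rope_dim helper added in this PR, with
    # model_config.get_head_size() replaced by a plain attribute.
    cfg = model_config.hf_text_config
    if model_config.use_mla:
        return cfg.qk_rope_head_dim
    rope_dim = model_config.head_size
    if hasattr(cfg, "partial_rotary_factor"):
        rope_dim = int(rope_dim * cfg.partial_rotary_factor)
    elif hasattr(cfg, "rotary_dim"):
        rope_dim = int(cfg.rotary_dim)
    return rope_dim


# Partial-rotary model: only a fraction of each head is rotated, so the
# rope dim (64 here) differs from head_dim (256) and must be passed separately.
partial = SimpleNamespace(
    use_mla=False,
    head_size=256,
    hf_text_config=SimpleNamespace(partial_rotary_factor=0.25),
)

# MLA model: the rope dim comes from qk_rope_head_dim in the HF config.
mla = SimpleNamespace(
    use_mla=True,
    head_size=128,
    hf_text_config=SimpleNamespace(qk_rope_head_dim=64),
)

print(resolve_rope_dim(partial))  # 64
print(resolve_rope_dim(mla))      # 64

Passing head_dim as the rope dim is harmless only when the two happen to be equal; for partial-rotary and MLA models the fused pattern would rotate the wrong number of dimensions, which is what this fix addresses.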
### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: zjks98
Signed-off-by: aipaes <82140963+aipaes@users.noreply.github.com>
Co-authored-by: zjks98
---
 .../passes/qknorm_rope_fusion_pass.py |  7 +++++--
 vllm_ascend/utils.py                  | 16 ++++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
index 31b0c6f4..b6ef5775 100644
--- a/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
+++ b/vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
@@ -24,6 +24,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.attention import Attention
 
 from vllm_ascend.compilation.passes.base_pattern import BasePattern
+from vllm_ascend.utils import get_rope_dim
 
 
 class QKNormRopeFusionPattern(BasePattern):
@@ -35,6 +36,7 @@ class QKNormRopeFusionPattern(BasePattern):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.device = vllm_config.device_config.device if vllm_config.device_config else None
+        self.rope_dim = get_rope_dim(vllm_config)
 
     def get_inputs(self):
         T = 5
@@ -65,7 +67,7 @@ class QKNormRopeFusionPattern(BasePattern):
             q_flat = q_norm_out.view(q.shape)
             k_flat = k_norm_out.view(k.shape)
             q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
-                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
+                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
             )
             return q_rope, k_rope, v
 
@@ -108,6 +110,7 @@ class QKNormRopeFusionPatternWithBias(BasePattern):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.device = vllm_config.device_config.device if vllm_config.device_config else None
+        self.rope_dim = get_rope_dim(vllm_config)
 
     def get_inputs(self):
         T = 5
@@ -145,7 +148,7 @@ class QKNormRopeFusionPatternWithBias(BasePattern):
             q_flat = q_normed.view(q.shape)
             k_flat = k_normed.view(k.shape)
             q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
-                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
+                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
             )
             return q_rope, k_rope, v
 
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 1c52627a..aec8decf 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1185,3 +1185,19 @@ def check_gdn_layer(vllm_config) -> bool:
             return True
 
     return False
+
+
+def get_rope_dim(vllm_config):
+    model_config = vllm_config.model_config
+
+    if model_config.use_mla:
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+    else:
+        rope_dim = model_config.get_head_size()
+        # For models using partial rope like Qwen3-Next.
+        if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
+            rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
+        elif hasattr(model_config.hf_text_config, "rotary_dim"):
+            rope_dim = int(model_config.hf_text_config.rotary_dim)
+
+    return rope_dim