From 4230bc86461e4b879e0b434eaa11cfbd2bf83ef7 Mon Sep 17 00:00:00 2001 From: wubin58 <1290313138@qq.com> Date: Fri, 30 Jan 2026 21:25:04 +0800 Subject: [PATCH] =?UTF-8?q?[Bugfix]Modify=20NPU=20rotary=20encoding=20para?= =?UTF-8?q?meter=20fields=EF=BC=8Cfix=20RopeOperation=20setup=20failed=20i?= =?UTF-8?q?n=20condition=20of=20self.rotary=5Fdim=20<=20self.head=5Fsize?= =?UTF-8?q?=20(#6310)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What this PR does / why we need it? Change self.head_size to self.rotary_dim. Only the rotary part is processed here, so the dimension should be rotary_dim. Fix bug #6060 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Only a small section of code was modified to adjust the parameters, and all standard tests passed. - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd Signed-off-by: fengshi666 Co-authored-by: fengshi666 --- vllm_ascend/ops/rotary_embedding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index 2f507b74..0c19b3e9 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -240,11 +240,13 @@ def _rope_forward_oot( k_pass = key[..., self.rotary_dim:] q_rot = q_rot.contiguous().view(num_tokens, -1) k_rot = k_rot.contiguous().view(num_tokens, -1) + # only the rotary part is processed here, + # the dimension should be rotary_dim torch_npu._npu_rotary_embedding( positions, q_rot, k_rot, - self.head_size, + self.rotary_dim, self.cos_sin_cache, is_neox_style, )