[bugfix] fix pass bug: pass real rope dim for npu_rotary_embedding (#6880)

### What this PR does / why we need it?
Pass the real rope dim to `npu_rotary_embedding` instead of reusing `self.head_dim` for both arguments.

**before:**
```python
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, q_flat, k_flat, cos_sin_cache, self.head_dim,
    self.head_dim, True
)
```

**after:**
```python
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, q_flat, k_flat, cos_sin_cache, self.head_dim,
    self.rope_dim, True
)
```
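
The two arguments differ for models that apply rope to only part of each head (partial rotary) or that use MLA with a dedicated `qk_rope_head_dim`; in those cases the rotary dim is smaller than `head_dim`. A rough illustration with hypothetical numbers:

```python
# Hypothetical partial-rope model: rope covers only a quarter of each head.
head_dim = 128
partial_rotary_factor = 0.25

rope_dim = int(head_dim * partial_rotary_factor)  # 32
# Passing head_dim (128) as the rotary-dim argument would apply rope across
# the whole head instead of the intended 32 dimensions, which is the bug
# this PR fixes.
```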
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: zjks98 <zhangjiakang4@huawei.com>
Signed-off-by: aipaes <82140963+aipaes@users.noreply.github.com>
Co-authored-by: zjks98 <zhangjiakang4@huawei.com>


```diff
@@ -24,6 +24,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.attention import Attention
 from vllm_ascend.compilation.passes.base_pattern import BasePattern
+from vllm_ascend.utils import get_rope_dim


 class QKNormRopeFusionPattern(BasePattern):
@@ -35,6 +36,7 @@ class QKNormRopeFusionPattern(BasePattern):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.device = vllm_config.device_config.device if vllm_config.device_config else None
+        self.rope_dim = get_rope_dim(vllm_config)

     def get_inputs(self):
         T = 5
@@ -65,7 +67,7 @@ class QKNormRopeFusionPattern(BasePattern):
         q_flat = q_norm_out.view(q.shape)
         k_flat = k_norm_out.view(k.shape)
         q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
-            positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
+            positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
         )
         return q_rope, k_rope, v
@@ -108,6 +110,7 @@ class QKNormRopeFusionPatternWithBias(BasePattern):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.device = vllm_config.device_config.device if vllm_config.device_config else None
+        self.rope_dim = get_rope_dim(vllm_config)

     def get_inputs(self):
         T = 5
@@ -145,7 +148,7 @@ class QKNormRopeFusionPatternWithBias(BasePattern):
         q_flat = q_normed.view(q.shape)
         k_flat = k_normed.view(k.shape)
         q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
-            positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
+            positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
         )
         return q_rope, k_rope, v
```


```diff
@@ -1185,3 +1185,19 @@ def check_gdn_layer(vllm_config) -> bool:
             return True
     return False
+
+
+def get_rope_dim(vllm_config):
+    model_config = vllm_config.model_config
+    if model_config.use_mla:
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+    else:
+        rope_dim = model_config.get_head_size()
+    # For models using partial rope like Qwen3-Next.
+    if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
+        rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
+    elif hasattr(model_config.hf_text_config, "rotary_dim"):
+        rope_dim = int(model_config.hf_text_config.rotary_dim)
+    return rope_dim
```
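
As a sanity check, the branching in `get_rope_dim` can be exercised with lightweight stand-in configs. This is a minimal sketch only: the `SimpleNamespace` objects and the numbers are illustrative, not the real vLLM `ModelConfig` or HF config classes, and the helper here takes `model_config` directly rather than the full `vllm_config`.

```python
from types import SimpleNamespace


def rope_dim_of(model_config):
    # Mirrors the get_rope_dim helper added in this PR (simplified inputs).
    if model_config.use_mla:
        rope_dim = model_config.hf_text_config.qk_rope_head_dim
    else:
        rope_dim = model_config.get_head_size()
    # Partial-rope models scale the rotary dim down from the full head size.
    if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
        rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
    elif hasattr(model_config.hf_text_config, "rotary_dim"):
        rope_dim = int(model_config.hf_text_config.rotary_dim)
    return rope_dim


# MLA-style model (hypothetical values): rope dim comes from qk_rope_head_dim.
mla_cfg = SimpleNamespace(
    use_mla=True,
    hf_text_config=SimpleNamespace(qk_rope_head_dim=64),
    get_head_size=lambda: 128,
)
assert rope_dim_of(mla_cfg) == 64

# Partial-rope model (hypothetical values): rope covers a quarter of the head.
partial_cfg = SimpleNamespace(
    use_mla=False,
    hf_text_config=SimpleNamespace(partial_rotary_factor=0.25),
    get_head_size=lambda: 128,
)
assert rope_dim_of(partial_cfg) == 32
```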