[bugfix] fix pass bug: pass the real rope dim to npu_rotary_embedding (#6880)
### What this PR does / why we need it?
Pass the real rope dim to `npu_rotary_embedding`. The QK-norm + RoPE fusion patterns previously passed `head_dim` where the rotary dim is expected, which is wrong for models whose rotary dim differs from the head dim (MLA models and models with partial rotary embeddings).
**before:**
```python
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, q_flat, k_flat, cos_sin_cache, self.head_dim,
    self.head_dim, True
)
```
**after:**
```python
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
    positions, q_flat, k_flat, cos_sin_cache, self.head_dim,
    self.rope_dim, True
)
```
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: zjks98 <zhangjiakang4@huawei.com>
Signed-off-by: aipaes <82140963+aipaes@users.noreply.github.com>
Co-authored-by: zjks98 <zhangjiakang4@huawei.com>
```diff
@@ -24,6 +24,7 @@ from vllm.logger import logger
 from vllm.model_executor.layers.attention import Attention
 
 from vllm_ascend.compilation.passes.base_pattern import BasePattern
+from vllm_ascend.utils import get_rope_dim
 
 
 class QKNormRopeFusionPattern(BasePattern):
@@ -35,6 +36,7 @@ class QKNormRopeFusionPattern(BasePattern):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.device = vllm_config.device_config.device if vllm_config.device_config else None
+        self.rope_dim = get_rope_dim(vllm_config)
 
     def get_inputs(self):
         T = 5
@@ -65,7 +67,7 @@ class QKNormRopeFusionPattern(BasePattern):
             q_flat = q_norm_out.view(q.shape)
             k_flat = k_norm_out.view(k.shape)
             q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
-                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
+                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
             )
 
             return q_rope, k_rope, v
@@ -108,6 +110,7 @@ class QKNormRopeFusionPatternWithBias(BasePattern):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.device = vllm_config.device_config.device if vllm_config.device_config else None
+        self.rope_dim = get_rope_dim(vllm_config)
 
     def get_inputs(self):
         T = 5
@@ -145,7 +148,7 @@ class QKNormRopeFusionPatternWithBias(BasePattern):
             q_flat = q_normed.view(q.shape)
             k_flat = k_normed.view(k.shape)
             q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
-                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.head_dim, True
+                positions, q_flat, k_flat, cos_sin_cache, self.head_dim, self.rope_dim, True
             )
 
             return q_rope, k_rope, v
```
```diff
@@ -1185,3 +1185,19 @@ def check_gdn_layer(vllm_config) -> bool:
                 return True
 
     return False
+
+
+def get_rope_dim(vllm_config):
+    model_config = vllm_config.model_config
+
+    if model_config.use_mla:
+        rope_dim = model_config.hf_text_config.qk_rope_head_dim
+    else:
+        rope_dim = model_config.get_head_size()
+        # For models using partial rope like Qwen3-Next.
+        if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
+            rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
+        elif hasattr(model_config.hf_text_config, "rotary_dim"):
+            rope_dim = int(model_config.hf_text_config.rotary_dim)
+
+    return rope_dim
```
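As a quick sanity check of the new helper, here is a rough usage sketch; `fake_vllm_config` is a hand-rolled stand-in (not a real vLLM config class) that only exposes the attributes `get_rope_dim` reads, and all numeric values are illustrative:
```python
from types import SimpleNamespace

# Requires vllm-ascend with this patch installed.
from vllm_ascend.utils import get_rope_dim


def fake_vllm_config(use_mla, head_size, **hf_fields):
    """Build a minimal stand-in exposing only what get_rope_dim reads."""
    hf_text_config = SimpleNamespace(**hf_fields)
    model_config = SimpleNamespace(
        use_mla=use_mla,
        hf_text_config=hf_text_config,
        get_head_size=lambda: head_size,
    )
    return SimpleNamespace(model_config=model_config)


# MLA model: rope dim comes from qk_rope_head_dim.
assert get_rope_dim(fake_vllm_config(True, 192, qk_rope_head_dim=64)) == 64

# Partial-rotary model: rope dim is head_size * partial_rotary_factor.
assert get_rope_dim(fake_vllm_config(False, 256, partial_rotary_factor=0.25)) == 64

# Full-rotary model: rope dim is simply the head size.
assert get_rope_dim(fake_vllm_config(False, 128)) == 128
```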