[bugfix] fix pass bug: pass really rope dim for npu_rotary_embedding (#6880)
### What this PR does / why we need it?
pass really rope dim for npu_rotary_embedding
**before:**
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim,
**self.head_dim,** True
)
**after:**
q_rope, k_rope = torch.ops.vllm.npu_rotary_embedding(
positions, q_flat, k_flat, cos_sin_cache, self.head_dim,
**self.rope_dim,** True
)
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: zjks98 <zhangjiakang4@huawei.com>
Signed-off-by: aipaes <82140963+aipaes@users.noreply.github.com>
Co-authored-by: zjks98 <zhangjiakang4@huawei.com>
This commit is contained in:
@@ -1185,3 +1185,19 @@ def check_gdn_layer(vllm_config) -> bool:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_rope_dim(vllm_config):
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if model_config.use_mla:
|
||||
rope_dim = model_config.hf_text_config.qk_rope_head_dim
|
||||
else:
|
||||
rope_dim = model_config.get_head_size()
|
||||
# For models using partial rope like Qwen3-Next.
|
||||
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
|
||||
rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
|
||||
elif hasattr(model_config.hf_text_config, "rotary_dim"):
|
||||
rope_dim = int(model_config.hf_text_config.rotary_dim)
|
||||
|
||||
return rope_dim
|
||||
|
||||
Reference in New Issue
Block a user