[BugFix] Add support for rotary_dim parameter when using partial rope in rotary_embedding (#6581)
### What this PR does / why we need it?
Issue: if a model such as Ling-1T adopts partial rotary position
embedding (partial RoPE) but its `config.json` specifies the `rotary_dim`
parameter instead of `partial_rotary_factor`, inference fails with
`RuntimeError: The expanded size of the tensor (128) must match the
existing size (64) at non-singleton dimension 3`.
<img width="1681" height="472" alt="image"
src="https://github.com/user-attachments/assets/ba03d7df-ecba-4d6f-9ec1-4dc55f59799e"
/>
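The shape mismatch can be reproduced in isolation. The sketch below is a hypothetical stand-in for the failing expand, not the actual vllm-ascend code path; the 128-dim head and 64-dim rotary section are assumed from the sizes in the error above:

```python
import torch

# Illustrative only: head_dim = 128 and rotary_dim = 64 are taken from the
# sizes in the error message, not read from the real Ling-1T config.
head_dim, rotary_dim, num_tokens, num_heads = 128, 64, 16, 8

# Without the fix, rope_dim falls back to the full head_dim, so the cos/sin
# buffers are 128 wide while only the first 64 query features are rotated.
cos = torch.ones(1, num_tokens, 1, head_dim)
q_rot = torch.randn(1, num_tokens, num_heads, rotary_dim)

try:
    # Expanding the 64-wide rotary slice to the 128-wide buffer shape fails.
    q_rot.expand(1, num_tokens, num_heads, head_dim)
except RuntimeError as e:
    print(e)
    # RuntimeError: The expanded size of the tensor (128) must match the
    # existing size (64) at non-singleton dimension 3. ...
```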
This PR adds support for the `rotary_dim` parameter in
`vllm_ascend/ops/rotary_embedding.py`: when `partial_rotary_factor` is
absent but `rotary_dim` is present, `rope_dim` is taken from
`rotary_dim`, so the cos/sin buffers match the rotated slice and the
tensor size mismatch is resolved.
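For illustration, a minimal sketch of the selection logic, in the order the patch applies it. The `resolve_rope_dim` helper and the `SimpleNamespace` configs are made up for this example, and the head size of 128 with `rotary_dim` of 64 are assumed from the error above rather than taken from the Ling-1T config:

```python
from types import SimpleNamespace

def resolve_rope_dim(hf_text_config, head_dim: int) -> int:
    """Prefer partial_rotary_factor, fall back to rotary_dim (as in the patch)."""
    rope_dim = head_dim
    if hasattr(hf_text_config, "partial_rotary_factor"):
        rope_dim = int(rope_dim * hf_text_config.partial_rotary_factor)
    elif hasattr(hf_text_config, "rotary_dim"):
        rope_dim = int(hf_text_config.rotary_dim)
    return rope_dim

# Both config styles yield the same rope_dim for an assumed 128-dim head:
factor_style = SimpleNamespace(partial_rotary_factor=0.5)  # e.g. Qwen3-Next style
dim_style = SimpleNamespace(rotary_dim=64)                  # e.g. Ling-1T style (assumed value)
print(resolve_rope_dim(factor_style, 128))  # 64
print(resolve_rope_dim(dim_style, 128))     # 64
```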
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
The patch was tested successfully with the Ling-1T model, which
previously triggered the error.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
Signed-off-by: GoCHug <93277779+GoCHug@users.noreply.github.com>
```diff
@@ -76,6 +76,8 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype, devi
     # For models using partial rope like Qwen3-Next.
     if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
         rope_dim = int(rope_dim * model_config.hf_text_config.partial_rotary_factor)
+    elif hasattr(model_config.hf_text_config, "rotary_dim"):
+        rope_dim = int(model_config.hf_text_config.rotary_dim)
     _cos = torch.ones(1, max_num_batched_tokens, 1, rope_dim, dtype=dtype, device=device)
     _sin = torch.zeros(1, max_num_batched_tokens, 1, rope_dim, dtype=dtype, device=device)
```