[BugFix]Fix wrong _cos, _sin instantiation (#5154)

### What this PR does / why we need it?
This PR adds an additional check when creating the global `_cos` and `_sin`,
to avoid creating them when using `mrope` or an encoder-decoder model.

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: Angazenn <supperccell@163.com>
This commit is contained in:
Angazenn
2025-12-20 22:52:50 +08:00
committed by GitHub
parent 5d02eed16f
commit 67a0325cf2
2 changed files with 33 additions and 18 deletions

View File

@@ -27,7 +27,7 @@ from vllm.model_executor.layers.rotary_embedding import (
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
get_ascend_device_type, is_vl_model)
get_ascend_device_type, has_rope, is_vl_model)
# Currently, rope ops used on NPU require detached cos and sin as inputs.
# However, RotaryEmbedding in vLLM uses cos_sin_cache as a single variable.
@@ -64,21 +64,22 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
model_config = vllm_config.model_config
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
if model_config.use_mla:
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
rope_dim = model_config.hf_text_config.qk_rope_head_dim
_cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
_sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
1,
1,
rope_dim,
dtype=dtype,
device=device)
elif not is_vl_model(vllm_config) and has_rope(vllm_config):
rope_dim = model_config.get_head_size()
# For models using partial rope like Qwen3-Next.
if hasattr(model_config.hf_text_config, "partial_rotary_factor"):