diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 566c48f4..628b0dd2 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -27,7 +27,7 @@ from vllm.model_executor.layers.rotary_embedding import (
 
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (AscendDeviceType, enable_custom_op,
-                               get_ascend_device_type, is_vl_model)
+                               get_ascend_device_type, has_rope, is_vl_model)
 
 # Currently, rope ops used on npu requires detached cos && sin as inputs.
 # However, RotaryEmbedding in vllm use cos_sin_cache as a whole variable.
@@ -64,21 +64,22 @@ def set_cos_and_sin(vllm_config, max_num_reqs, decode_token_per_req, dtype,
     model_config = vllm_config.model_config
     max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
 
-    if model_config.use_mla and compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
-        rope_dim = model_config.hf_text_config.qk_rope_head_dim
-        _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
-                              1,
-                              1,
-                              rope_dim,
-                              dtype=dtype,
-                              device=device)
-        _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
-                               1,
-                               1,
-                               rope_dim,
-                               dtype=dtype,
-                               device=device)
-    elif not is_vl_model(vllm_config) and not vllm_config.model_config.use_mla:
+    if model_config.use_mla:
+        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+            rope_dim = model_config.hf_text_config.qk_rope_head_dim
+            _cos_mla = torch.ones(max_num_reqs * decode_token_per_req,
+                                  1,
+                                  1,
+                                  rope_dim,
+                                  dtype=dtype,
+                                  device=device)
+            _sin_mla = torch.zeros(max_num_reqs * decode_token_per_req,
+                                   1,
+                                   1,
+                                   rope_dim,
+                                   dtype=dtype,
+                                   device=device)
+    elif not is_vl_model(vllm_config) and has_rope(vllm_config):
         rope_dim = model_config.get_head_size()
         # For models using partial rope like Qwen3-Next.
         if hasattr(model_config.hf_text_config, "partial_rotary_factor"):
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index e7b3b8ec..d6737178 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -64,6 +64,7 @@ _HAS_LAYER_IDX = None
 _SUBSCRIBED_COMPUTE_STREAMS = set()
 _GRAPH_PRINT_STREAM = None
 _GRAPH_PRINT_STREAM_LOCK = Lock()
+_HAS_ROPE = None
 
 
 def _print_callback_on_stream(*args):
@@ -823,11 +824,24 @@ def is_vl_model(vllm_config: VllmConfig):
     """Checks if the model is a VL model by config"""
     global _IS_VL_MODEL
     if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
-        model_configs = vllm_config.model_config.hf_config.to_dict()
-        _IS_VL_MODEL = "VL" in model_configs["architectures"][0]
+        hf_config = vllm_config.model_config.hf_config.to_dict()
+        if "thinker_config" in hf_config:
+            # Qwen-Omni-thinker models
+            _IS_VL_MODEL = True
+        else:
+            _IS_VL_MODEL = "vision_config" in hf_config
     return _IS_VL_MODEL
 
 
+def has_rope(vllm_config: VllmConfig):
+    """Checks if the model uses rope."""
+    global _HAS_ROPE
+    if _HAS_ROPE is None and vllm_config and vllm_config.model_config:
+        hf_config = vllm_config.model_config.hf_config.to_dict()
+        _HAS_ROPE = "rope_parameters" in hf_config
+    return _HAS_ROPE
+
+
 def weak_ref_tensor(tensor: Any) -> Any:
     """
     Create a weak reference to a tensor.
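
As a review aid (not part of the patch), here is a minimal sketch of the key-based detection that the updated `is_vl_model` and the new `has_rope` perform on `hf_config.to_dict()`. The `_is_vl_config` / `_has_rope_config` helpers and the sample dicts below are hypothetical stand-ins; the real functions in `vllm_ascend/utils.py` additionally memoize their result in the module-level `_IS_VL_MODEL` / `_HAS_ROPE` globals so the dict conversion only runs once per process.

```python
# Hypothetical stand-ins for the detection logic in vllm_ascend/utils.py.
# The dicts model the output of hf_config.to_dict(); caching is omitted.


def _is_vl_config(hf_config: dict) -> bool:
    # Qwen-Omni "thinker" models nest their config under "thinker_config",
    # so that key alone marks the model as vision-language.
    if "thinker_config" in hf_config:
        return True
    # Otherwise, a top-level "vision_config" identifies a VL model; this
    # replaces the old, brittle check for "VL" in the architecture name.
    return "vision_config" in hf_config


def _has_rope_config(hf_config: dict) -> bool:
    # Models that use rotary embeddings carry a "rope_parameters" entry.
    return "rope_parameters" in hf_config


# Sample configs (hypothetical) exercising each branch:
assert _is_vl_config({"thinker_config": {}}) is True
assert _is_vl_config({"vision_config": {}}) is True
assert _is_vl_config({"architectures": ["SomeForCausalLM"]}) is False
assert _has_rope_config({"rope_parameters": {"rope_theta": 10000}}) is True
assert _has_rope_config({}) is False
```

This is why the `elif` branch in `set_cos_and_sin` can gate on `has_rope(vllm_config)` instead of re-deriving the answer from the architecture string: a non-MLA, non-VL model only gets the detached cos/sin buffers when its config actually declares rope parameters.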