From 9bba0a2a689dc6586996c0f4f8ee0bf7c65e88db Mon Sep 17 00:00:00 2001 From: zhangxinyuehfad <59153331+zhangxinyuehfad@users.noreply.github.com> Date: Thu, 22 Jan 2026 15:46:05 +0800 Subject: [PATCH] [Bugfix] Fix Triton operator usage for multimodal models based on the `mrope_interleaved` parameter (#6042) ### What this PR does / why we need it? When running the Qwen2.5-Omni-7B model on Ascend NPU, the engine fails during the profiling/warmup stage with the following error: `AclNN_Runtime_Error(EZ9903): rtKernelLaunchWithHandleV2 failed: 507035. The vector core execution is abnormal.` error log: https://github.com/vllm-project/vllm-ascend/actions/runs/21144534911/job/60806765393#step:17:6412 This error is specifically triggered by the `triton_mrope` kernel when handling the unique `mrope_section` configurations of the Omni model. Other multimodal models with standard sections (e.g., [16, 24, 24]) or standard LLMs work correctly with Triton. Modified vllm_ascend/ops/rotary_embedding.py to add a conditional check before calling forward_triton. 1. For standard LLMs (`mrope_interleaved = True`), it continues to use Triton for acceleration. 2. For complex configurations (like Qwen2.5-Omni, `mrope_interleaved = False`), it now falls back to the native super().forward_oot() path, which uses the stable torch_npu or PyTorch implementation. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? 
- vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: hfadzxy --- vllm_ascend/ops/rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py index afc01d55..04c9d302 100644 --- a/vllm_ascend/ops/rotary_embedding.py +++ b/vllm_ascend/ops/rotary_embedding.py @@ -586,7 +586,7 @@ class AscendMRotaryEmbedding(MRotaryEmbedding): query: torch.Tensor, key: torch.Tensor, ): - if HAS_TRITON and positions.ndim == 2: + if HAS_TRITON and positions.ndim == 2 and self.mrope_interleaved: # todo: need cann update in 8.5.0 return self.forward_triton(positions, query, key)