From 85234d096dd558980c6f9731e4398997b0af97da Mon Sep 17 00:00:00 2001
From: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com>
Date: Thu, 2 Apr 2026 17:46:50 +0800
Subject: [PATCH] [v0.18.0][Feature] support qkv_rmsnorm_mrope for qwen3vl
 (#7852)

### What this PR does / why we need it?
Enable the split_qkv_rmsnorm_mrope fusion operator for Qwen3-VL full
attention. (Illustrative sketches of the patching mechanism and of the
unfused reference path follow the diff below.)

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- [x] Run the Qwen3-VL dense model with the fusion operator and verify correct output
- [x] Run the Qwen3-VL MoE model with the fusion operator and verify correct output

---------

Signed-off-by: jiangmengyu18 <451528648@qq.com>
Signed-off-by: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com>
Signed-off-by: betta18
Co-authored-by: betta18
---
 vllm_ascend/patch/__init__.py             | 21 ++++++++++-
 vllm_ascend/patch/worker/__init__.py      |  1 +
 vllm_ascend/patch/worker/patch_qwen3vl.py | 44 +++++++++++++++++++++++
 3 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 vllm_ascend/patch/worker/patch_qwen3vl.py

diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 40abccca..c8da9365 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -701,4 +701,23 @@
 #     Let vLLM support triton ops dispatch.
 #   Future Plan:
 #     Remove this patch when vLLM support the dispatch function.
-#
+# ** 27. File: worker/patch_qwen3vl.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.models.qwen3.Qwen3Attention.forward` and
+#      `vllm.model_executor.models.qwen3_moe.Qwen3MoeAttention.forward`
+#   Why:
+#     Support the triton_split_qkv_rmsnorm_mrope fused kernel for Qwen3Attention and Qwen3MoeAttention.
+#   How:
+#     Override the forward method with the triton_split_qkv_rmsnorm_mrope fused kernel
+#     when mrope is in use.
+#   Future Plan:
+#     Remove this patch when vllm-ascend supports pattern matching for this fused kernel.
+# ** 28. File: worker/patch_qwen3vl.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.models.qwen3_vl.Qwen3VLForConditionalGeneration._get_deepstack_input_embeds`
+#   Why:
+#     Support flash comm v1 for Qwen3-VL.
+#   How:
+#     Override the _get_deepstack_input_embeds method with the flash comm v1 implementation.
+#   Future Plan:
+#     Remove this patch when https://github.com/vllm-project/vllm-ascend/issues/5712 is completed.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 4d2dbfd7..b79ad7c0 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -49,4 +49,5 @@ import vllm_ascend.patch.worker.patch_deepseek_mtp  # noqa
 import vllm_ascend.patch.worker.patch_v2.patch_input_batch  # noqa
 import vllm_ascend.patch.worker.patch_v2.patch_model_state  # noqa
 import vllm_ascend.patch.worker.patch_v2.patch_block_table  # noqa
+import vllm_ascend.patch.worker.patch_qwen3vl  # noqa
 import vllm_ascend.patch.worker.patch_deepencoder2  # noqa
diff --git a/vllm_ascend/patch/worker/patch_qwen3vl.py b/vllm_ascend/patch/worker/patch_qwen3vl.py
new file mode 100644
index 00000000..514de099
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_qwen3vl.py
@@ -0,0 +1,44 @@
+import torch
+from vllm.model_executor.models.qwen3 import Qwen3Attention
+from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention
+
+from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding
+
+
+def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor, hidden_states: torch.Tensor):
+    qkv, _ = self.qkv_proj(hidden_states)
+    if isinstance(self.rotary_emb, AscendMRotaryEmbedding):
+        cos_sin = self.rotary_emb.cos_sin_cache[positions]
+        if cos_sin.device != qkv.device:
+            cos_sin = cos_sin.to(qkv.device)
+        if cos_sin.dtype != qkv.dtype:
+            cos_sin = cos_sin.to(qkv.dtype)
+        q, k, v, _ = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
+            qkv=qkv,
+            q_weight=self.q_norm.weight,
+            k_weight=self.k_norm.weight,
+            cos_sin=cos_sin,
+            num_q_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_dim,
+            eps=self.q_norm.variance_epsilon,
+            mrope_section=self.rotary_emb.mrope_section,
+            is_interleaved=self.rotary_emb.mrope_interleaved,
+            rope_dim=self.rotary_emb.rotary_dim,
+        )
+    else:
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        q, k = self.rotary_emb(positions, q, k)
+    attn_output = self.attn(q, k, v)
+    output, _ = self.o_proj(attn_output)
+    return output
+
+
+Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope
+Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope
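
The patch takes effect purely through import side effects: `vllm_ascend/patch/worker/__init__.py` imports `patch_qwen3vl`, and the module-level assignments at the bottom of that file rebind `Qwen3Attention.forward` and `Qwen3MoeAttention.forward` for every instance. A minimal, self-contained sketch of this mechanism follows; the `Attention` class and `patched_forward` function are illustrative stand-ins, not vLLM APIs.

```python
# Sketch of class-attribute monkey-patching as used by patch_qwen3vl.py.
# `Attention` and `patched_forward` are illustrative names, not vLLM's.
class Attention:
    def forward(self, x: int) -> int:
        return x  # original behaviour


def patched_forward(self, x: int) -> int:
    return 2 * x  # replacement behaviour, picked up by all instances


# Rebinding the class attribute mirrors
# `Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope`.
Attention.forward = patched_forward

assert Attention().forward(3) == 6
```

Because the rebinding happens at import time, importing the patch module once is enough, which is why `worker/__init__.py` only needs the bare `# noqa` import.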
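
For reference, the unfused fallback branch of `forward_with_split_qkv_rmsnorm_mrope` (taken when the rotary embedding is not an `AscendMRotaryEmbedding`) is equivalent in spirit to the standalone sketch below. The `rms_norm` helper, its eps default, and the tensor sizes in the smoke test are illustrative assumptions, and the rotary-embedding step that the fused kernel also performs is omitted here.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Per-head RMSNorm over the last dimension, standing in for q_norm/k_norm.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def split_qkv_rmsnorm_reference(qkv, q_size, kv_size, head_dim, q_weight, k_weight):
    # Split the packed QKV projection, then RMS-normalize q and k head by head,
    # mirroring the else-branch of the patched forward (rope left out here).
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    q = rms_norm(q.view(*q.shape[:-1], -1, head_dim), q_weight).view(q.shape)
    k = rms_norm(k.view(*k.shape[:-1], -1, head_dim), k_weight).view(k.shape)
    return q, k, v


# Smoke test with illustrative sizes: 2 query heads, 1 kv head, head_dim = 4.
qkv = torch.randn(3, 16)
q, k, v = split_qkv_rmsnorm_reference(qkv, 8, 4, 4, torch.ones(4), torch.ones(4))
assert q.shape == (3, 8) and k.shape == (3, 4) and v.shape == (3, 4)
```

The fused `triton_split_qkv_rmsnorm_mrope` path performs the same split and normalization plus the mrope rotation in a single kernel launch, which is what the patch enables for Qwen3-VL full attention.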