From e942b62d742ebc5bf128e85bc086d728df8d4935 Mon Sep 17 00:00:00 2001
From: ZhuQi-seu
Date: Mon, 23 Mar 2026 23:58:12 +0800
Subject: [PATCH] [features] support split qkv rmsnorm mrope for qwen3.5 (#7368)

### What this PR does / why we need it?
Enable the `split_qkv_rmsnorm_mrope` fusion operator for Qwen3.5 full attention.

### How was this patch tested?
vLLM version: v0.16.0
vLLM-Ascend main: https://github.com/vllm-project/vllm-ascend/pull/6730

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: ZhuQi-seu
---
 vllm_ascend/patch/worker/patch_qwen3_5.py | 51 +++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/vllm_ascend/patch/worker/patch_qwen3_5.py b/vllm_ascend/patch/worker/patch_qwen3_5.py
index 748438fb..4bbe5ead 100644
--- a/vllm_ascend/patch/worker/patch_qwen3_5.py
+++ b/vllm_ascend/patch/worker/patch_qwen3_5.py
@@ -22,6 +22,7 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
 from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
+from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
 from vllm.v1.attention.backend import AttentionMetadata  # type: ignore
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -260,4 +261,54 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
         maybe_save_kv_layer_to_connector("", [])
 
 
+class AscendQwen3NextAttention(Qwen3NextAttention):
+    def forward(self, positions: torch.Tensor, output: torch.Tensor,
+                hidden_states: torch.Tensor):
+        qkv, _ = self.qkv_proj(hidden_states)
+        if "qwen3_5" in self.config.model_type:
+            cos_sin = self.rotary_emb.cos_sin_cache[positions]
+            if cos_sin.device != qkv.device:
+                cos_sin = cos_sin.to(qkv.device)
+            if cos_sin.dtype != qkv.dtype:
+                cos_sin = cos_sin.to(qkv.dtype)
+            q, k, v, gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
+                qkv=qkv,
+                q_weight=1.0 + self.q_norm.weight,
+                k_weight=1.0 + self.k_norm.weight,
+                cos_sin=cos_sin,
+                num_q_heads=self.num_heads,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_dim,
+                eps=self.config.rms_norm_eps,
+                mrope_section=self.rotary_emb.mrope_section,
+                is_interleaved=self.rotary_emb.mrope_interleaved,
+                rope_dim=self.rotary_emb.rotary_dim,
+                has_gate=self.attn_output_gate,
+            )
+        else:
+            if self.attn_output_gate:
+                q_gate, k, v = qkv.split(
+                    [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
+                orig_shape = q_gate.shape[:-1]
+                q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
+                q, gate = torch.chunk(q_gate, 2, dim=-1)
+                q = q.reshape(*orig_shape, -1)
+                gate = gate.reshape(*orig_shape, -1)
+            else:
+                q, k, v = qkv.split(
+                    [self.q_size, self.kv_size, self.kv_size], dim=-1)
+            q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
+                -1, self.num_heads * self.head_dim)
+            k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
+                -1, self.num_kv_heads * self.head_dim)
+            q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v)
+        if self.attn_output_gate:
+            gate = torch.sigmoid(gate)
+            attn_output = attn_output * gate
+        output[:], _ = self.o_proj(attn_output)
+
+
 Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
+Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
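
Note: the fused `triton_split_qkv_rmsnorm_mrope` call above replaces the step-by-step work the fallback branch does (split the packed QKV projection, apply per-head RMSNorm to Q and K, then apply rotary embeddings). The sketch below is a minimal, unfused reference for the split + RMSNorm portion only (it deliberately omits MRoPE and the output gate); `split_qkv_rmsnorm_reference` and its toy shapes are illustrative and not part of vLLM or vllm-ascend.

```python
import torch


def split_qkv_rmsnorm_reference(qkv: torch.Tensor,
                                q_weight: torch.Tensor,
                                k_weight: torch.Tensor,
                                num_q_heads: int,
                                num_kv_heads: int,
                                head_size: int,
                                eps: float = 1e-6):
    """Unfused reference for part of the fused kernel's work.

    qkv:       [num_tokens, (num_q_heads + 2 * num_kv_heads) * head_size]
    q_weight / k_weight: [head_size] RMSNorm scales (already including the
    +1 offset that the patch passes as ``1.0 + self.q_norm.weight``).
    """
    q_size = num_q_heads * head_size
    kv_size = num_kv_heads * head_size
    # Split the packed projection into Q, K, V along the last dimension.
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

    def rmsnorm(x: torch.Tensor, weight: torch.Tensor, num_heads: int):
        # Per-head RMSNorm in float32 for numerical stability.
        x = x.view(-1, num_heads, head_size).float()
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + eps)
        return (x * weight.float()).to(qkv.dtype).view(-1, num_heads * head_size)

    return rmsnorm(q, q_weight, num_q_heads), rmsnorm(k, k_weight, num_kv_heads), v


if __name__ == "__main__":
    # Toy shapes only; real Qwen3.5 sizes come from the model's HF config.
    tokens, heads, kv_heads, head_size = 4, 8, 2, 16
    qkv = torch.randn(tokens, (heads + 2 * kv_heads) * head_size)
    q, k, v = split_qkv_rmsnorm_reference(
        qkv, torch.ones(head_size), torch.ones(head_size),
        heads, kv_heads, head_size)
    print(q.shape, k.shape, v.shape)
```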