[Features] Support split_qkv_rmsnorm_mrope for Qwen3.5 (#7368)

### What this PR does / why we need it?
Enables the split_qkv_rmsnorm_mrope fusion operator for Qwen3.5 full
attention, which folds the QKV split, the per-head RMSNorm of q/k, and the
MROPE rotation into a single kernel instead of separate ops.
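
For reference, a minimal sketch (not the Triton kernel) of what the fused op is assumed to compute, matching the unfused fallback path in this patch: split the packed QKV projection, apply per-head RMSNorm to q and k, then rotate with MROPE. The helper below is hypothetical; it omits the MROPE rotation and the optional packed output gate, which the real kernel handles in the same pass:

```python
import torch

def per_head_rms_norm(x: torch.Tensor, gamma: torch.Tensor,
                      head_size: int, eps: float) -> torch.Tensor:
    """RMSNorm per attention head; gamma already includes the zero-centered
    offset (the patch passes `1.0 + norm.weight` to the kernel)."""
    shape = x.shape
    x = x.reshape(-1, x.shape[-1] // head_size, head_size).float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x * gamma.float()).reshape(shape).to(gamma.dtype)

num_q_heads, num_kv_heads, head_size, eps = 8, 2, 64, 1e-6
q_size, kv_size = num_q_heads * head_size, num_kv_heads * head_size
qkv = torch.randn(4, q_size + 2 * kv_size, dtype=torch.bfloat16)
q_w = torch.zeros(head_size, dtype=torch.bfloat16)  # zero-centered weight
k_w = torch.zeros(head_size, dtype=torch.bfloat16)

q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q = per_head_rms_norm(q, 1.0 + q_w, head_size, eps)
k = per_head_rms_norm(k, 1.0 + k_w, head_size, eps)
# MROPE (cos_sin, mrope_section, interleaving) would rotate q/k here.
```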

### How was this patch tested?
- vLLM version: v0.16.0
- vLLM-Ascend main: https://github.com/vllm-project/vllm-ascend/pull/6730
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: ZhuQi-seu <zhuqi12@huawei.com>
Commit e942b62d74 (parent 8e0789bb36)
ZhuQi-seu committed 2026-03-23 23:58:12 +08:00

@@ -22,6 +22,7 @@ from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@@ -260,4 +261,54 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
        maybe_save_kv_layer_to_connector("", [])


class AscendQwen3NextAttention(Qwen3NextAttention):

    def forward(self, positions: torch.Tensor, output: torch.Tensor,
                hidden_states: torch.Tensor):
        qkv, _ = self.qkv_proj(hidden_states)
        if "qwen3_5" in self.config.model_type:
            # Fused path: one kernel splits QKV, applies per-head RMSNorm to
            # q/k, and rotates them with MROPE. Move the cos/sin cache to the
            # QKV tensor's device/dtype if they differ.
            cos_sin = self.rotary_emb.cos_sin_cache[positions]
            if cos_sin.device != qkv.device:
                cos_sin = cos_sin.to(qkv.device)
            if cos_sin.dtype != qkv.dtype:
                cos_sin = cos_sin.to(qkv.dtype)
            # Pass (1 + weight) as the effective RMSNorm gamma
            # (zero-centered weight parameterization).
            q, k, v, gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
                qkv=qkv,
                q_weight=1.0 + self.q_norm.weight,
                k_weight=1.0 + self.k_norm.weight,
                cos_sin=cos_sin,
                num_q_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_dim,
                eps=self.config.rms_norm_eps,
                mrope_section=self.rotary_emb.mrope_section,
                is_interleaved=self.rotary_emb.mrope_interleaved,
                rope_dim=self.rotary_emb.rotary_dim,
                has_gate=self.attn_output_gate,
            )
        else:
            # Unfused fallback: split, normalize, and rotate step by step.
            if self.attn_output_gate:
                # q and the output gate are packed together, two values per head.
                q_gate, k, v = qkv.split(
                    [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
                orig_shape = q_gate.shape[:-1]
                q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
                q, gate = torch.chunk(q_gate, 2, dim=-1)
                q = q.reshape(*orig_shape, -1)
                gate = gate.reshape(*orig_shape, -1)
            else:
                q, k, v = qkv.split(
                    [self.q_size, self.kv_size, self.kv_size], dim=-1)
            q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(
                -1, self.num_heads * self.head_dim)
            k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(
                -1, self.num_kv_heads * self.head_dim)
            q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)
        if self.attn_output_gate:
            # Sigmoid output gate applied to the attention result.
            gate = torch.sigmoid(gate)
            attn_output = attn_output * gate
        output[:], _ = self.o_proj(attn_output)


# Rebind the upstream vLLM methods so the Ascend implementations take
# effect without modifying vLLM itself.
Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
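
The two assignments above rebind methods on the upstream vLLM classes at import time. A minimal, self-contained sketch of that pattern, with stand-in class names rather than vLLM types:

```python
class Upstream:
    def forward(self, x: float) -> float:
        return x

class AscendUpstream(Upstream):
    def forward(self, x: float) -> float:
        # Hardware-specific implementation would go here.
        return x * 2.0

# Rebind: every existing instance and call site of Upstream.forward
# now routes to the override, without touching the upstream module.
Upstream.forward = AscendUpstream.forward

m = Upstream()
assert m.forward(3.0) == 6.0
```

This keeps the plugin decoupled from vLLM internals: the override only has to match the upstream method's signature.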