[features] support split qkv rmsnorm mrope for qwen3.5 (#7368)
### What this PR does / why we need it?
Qwen3.5 full attention supports enabling the split_qkv_rmsnorm_mrope
fusion operator.
### How was this patch tested?
vLLM version: v0.16.0
vLLM-Ascend main: https://github.com/vllm-project/vllm-ascend/pull/6730
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: ZhuQi-seu <zhuqi12@huawei.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
|
||||
from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
|
||||
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
|
||||
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
|
||||
@@ -260,4 +261,54 @@ class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
|
||||
maybe_save_kv_layer_to_connector("", [])
|
||||
|
||||
|
||||
class AscendQwen3NextAttention(Qwen3NextAttention):
    """Ascend full-attention layer for Qwen3-Next-family models.

    For ``qwen3_5`` model types the QKV split, Q/K RMSNorm and MROPE rotation
    are delegated to a single fused custom op; every other model type falls
    back to the plain split -> norm -> rope sequence.
    """

    def forward(self, positions: torch.Tensor, output: torch.Tensor, hidden_states: torch.Tensor):
        # Joint QKV projection; the second return value is unused here.
        qkv, _ = self.qkv_proj(hidden_states)

        if "qwen3_5" in self.config.model_type:
            # Fused path: one kernel performs the qkv split, the RMSNorm on
            # q/k and the multimodal rotary embedding in a single pass.
            cos_sin = self.rotary_emb.cos_sin_cache[positions]
            # The cached cos/sin table may live on a different device or in a
            # different dtype than the activations; align it before the call.
            if cos_sin.device != qkv.device:
                cos_sin = cos_sin.to(qkv.device)
            if cos_sin.dtype != qkv.dtype:
                cos_sin = cos_sin.to(qkv.dtype)

            q, k, v, gate = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
                qkv=qkv,
                q_weight=1.0 + self.q_norm.weight,
                k_weight=1.0 + self.k_norm.weight,
                cos_sin=cos_sin,
                num_q_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_size=self.head_dim,
                eps=self.config.rms_norm_eps,
                mrope_section=self.rotary_emb.mrope_section,
                is_interleaved=self.rotary_emb.mrope_interleaved,
                rope_dim=self.rotary_emb.rotary_dim,
                has_gate=self.attn_output_gate,
            )
        else:
            # Unfused fallback path.
            if self.attn_output_gate:
                # q and its output gate are packed together in the projection:
                # split them apart per head, then flatten back to the
                # projected (flat) layout.
                packed_q, k, v = qkv.split([self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
                lead_shape = packed_q.shape[:-1]
                per_head = packed_q.view(*lead_shape, self.num_heads, -1)
                q, gate = torch.chunk(per_head, 2, dim=-1)
                q = q.reshape(*lead_shape, -1)
                gate = gate.reshape(*lead_shape, -1)
            else:
                q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

            # Per-head RMSNorm on q and k, then back to the flat layout.
            q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)).view(-1, self.num_heads * self.head_dim)
            k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)).view(-1, self.num_kv_heads * self.head_dim)

            q, k = self.rotary_emb(positions, q, k)

        attn_output = self.attn(q, k, v)

        if self.attn_output_gate:
            # Sigmoid-gate the attention output.
            attn_output = attn_output * torch.sigmoid(gate)

        # Write the projected result into the caller-provided buffer in place.
        output[:], _ = self.o_proj(attn_output)
|
||||
# Monkey-patch the upstream Qwen3.5 gated-delta-net core with the Ascend implementation.
Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
|
||||
# Monkey-patch the upstream Qwen3-Next attention forward with the Ascend implementation.
Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
|
||||
|
||||
Reference in New Issue
Block a user