"""Patch Qwen3 / Qwen3-MoE attention to use a fused split-QKV + RMSNorm + MRoPE path.

When the layer's rotary embedding is an ``AscendMRotaryEmbedding``, the forward below
replaces the separate QKV split, per-head RMSNorm, and rotary application with a single
fused Triton custom op; otherwise it falls back to the reference implementation.
"""

import torch

from vllm.model_executor.models.qwen3 import Qwen3Attention
from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention

from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding


def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor,
                                         hidden_states: torch.Tensor):
    # Fused QKV projection: one matmul producing [q | k | v] along the last dim.
    qkv, _ = self.qkv_proj(hidden_states)

    if isinstance(self.rotary_emb, AscendMRotaryEmbedding):
        # Gather cached cos/sin values for the requested positions and make sure
        # they match the device and dtype of the QKV activations.
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        if cos_sin.device != qkv.device:
            cos_sin = cos_sin.to(qkv.device)
        if cos_sin.dtype != qkv.dtype:
            cos_sin = cos_sin.to(qkv.dtype)

        # Fused kernel: split QKV, apply per-head RMSNorm to Q and K, and apply
        # multimodal RoPE (MRoPE) in a single Triton launch.
        q, k, v, _ = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
            qkv=qkv,
            q_weight=self.q_norm.weight,
            k_weight=self.k_norm.weight,
            cos_sin=cos_sin,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_dim,
            eps=self.q_norm.variance_epsilon,
            mrope_section=self.rotary_emb.mrope_section,
            is_interleaved=self.rotary_emb.mrope_interleaved,
            rope_dim=self.rotary_emb.rotary_dim,
        )
    else:
        # Reference path: split QKV, normalize Q and K per head, then apply the
        # layer's rotary embedding as usual.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)

        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)

        q, k = self.rotary_emb(positions, q, k)

    # Attention and output projection are unchanged in both paths.
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


# Rebind the forward of both dense and MoE Qwen3 attention layers to the patched version.
Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope
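

# A minimal usage sketch, assuming this file is importable as a module in a vLLM
# deployment: importing it is enough to install the patched forward, because the
# class attributes above are rebound at import time, before any layers are built.
# The model name below is illustrative; any Qwen3 / Qwen3-MoE checkpoint applies.
if __name__ == "__main__":
    from vllm import LLM

    llm = LLM(model="Qwen/Qwen3-8B")  # illustrative checkpoint
    print(llm.generate(["Hello from the patched attention path."]))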