[v0.18.0][Feature] support qkv_rmsnorm_mrope for qwen3vl (#7852)
### What this PR does / why we need it?
Qwen3-VL full attention supports enabling the split_qkv_rmsnorm_mrope fusion operator.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- [x] Ran the Qwen3-VL dense model with the fusion operator and verified correct output
- [x] Ran the Qwen3-VL MoE model with the fusion operator and verified correct output

---------

Signed-off-by: jiangmengyu18 <451528648@qq.com>
Signed-off-by: jiangmengyu18 <56633611+jiangmengyu18@users.noreply.github.com>
Signed-off-by: betta18 <jiangmengyu1@huawei.com>
Co-authored-by: betta18 <jiangmengyu1@huawei.com>
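For context, the patch in this diff takes effect as an import side effect: importing the new module rebinds `Qwen3Attention.forward` and `Qwen3MoeAttention.forward`. A minimal sketch of triggering it by hand follows; whether vllm-ascend auto-loads its worker patches, and at what point, is not stated in this PR:

```python
# Sketch only: the patch applies itself at import time by rebinding
# Qwen3Attention.forward / Qwen3MoeAttention.forward, so importing the
# module is enough and no call is needed. The automatic loading
# mechanism for worker patches is an assumption not shown in this PR.
import vllm_ascend.patch.worker.patch_qwen3vl  # noqa: F401
```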
vllm_ascend/patch/worker/patch_qwen3vl.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import torch
from vllm.model_executor.models.qwen3 import Qwen3Attention
from vllm.model_executor.models.qwen3_moe import Qwen3MoeAttention

from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding


def forward_with_split_qkv_rmsnorm_mrope(self, positions: torch.Tensor,
                                         hidden_states: torch.Tensor):
    qkv, _ = self.qkv_proj(hidden_states)
    if isinstance(self.rotary_emb, AscendMRotaryEmbedding):
        # Fused path: gather cached cos/sin values for these positions and
        # match them to the qkv tensor's device and dtype before the fused
        # kernel consumes them.
        cos_sin = self.rotary_emb.cos_sin_cache[positions]
        if cos_sin.device != qkv.device:
            cos_sin = cos_sin.to(qkv.device)
        if cos_sin.dtype != qkv.dtype:
            cos_sin = cos_sin.to(qkv.dtype)
        # One fused op performs the QKV split, the per-head q/k RMSNorm,
        # and the MRoPE rotation.
        q, k, v, _ = torch.ops.vllm.triton_split_qkv_rmsnorm_mrope(
            qkv=qkv,
            q_weight=self.q_norm.weight,
            k_weight=self.k_norm.weight,
            cos_sin=cos_sin,
            num_q_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_dim,
            eps=self.q_norm.variance_epsilon,
            mrope_section=self.rotary_emb.mrope_section,
            is_interleaved=self.rotary_emb.mrope_interleaved,
            rope_dim=self.rotary_emb.rotary_dim,
        )
    else:
        # Unfused fallback: split QKV, apply RMSNorm per head, then rotate.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
        q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                           self.head_dim)
        q_by_head = self.q_norm(q_by_head)
        q = q_by_head.view(q.shape)
        k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                           self.head_dim)
        k_by_head = self.k_norm(k_by_head)
        k = k_by_head.view(k.shape)
        q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


# Monkey-patch both the dense and MoE Qwen3 attention layers so Qwen3-VL
# picks up the fused forward.
Qwen3Attention.forward = forward_with_split_qkv_rmsnorm_mrope
Qwen3MoeAttention.forward = forward_with_split_qkv_rmsnorm_mrope
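For intuition about the unfused fallback above, and about what the fused operator collapses into a single call, here is a small standalone sketch of the split and the per-head view that RMSNorm is applied over. The tensor sizes are hypothetical, chosen only for illustration:

```python
import torch

# Hypothetical sizes for illustration only (not from the PR).
num_tokens, num_q_heads, num_kv_heads, head_dim = 4, 8, 2, 128
q_size, kv_size = num_q_heads * head_dim, num_kv_heads * head_dim

# The QKV projection emits one tensor; the unfused path splits it...
qkv = torch.randn(num_tokens, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)

# ...then views q (and likewise k) per head, so RMSNorm normalizes each
# head_dim slice independently, and flattens back before the rotation.
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
assert q_by_head.shape == (num_tokens, num_q_heads, head_dim)
q = q_by_head.view(q.shape)  # inverse view after normalization
```

The fused `triton_split_qkv_rmsnorm_mrope` op performs this split, the per-head RMSNorm, and the MRoPE rotation in one fused operator instead of as separate tensor ops.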