[Ops] Add layernorm for qwen3Next (#5765)

### What this PR does / why we need it?
Add layernormFn triton op for qwen3Next model for better performance.

<img width="248" height="526" alt="image"
src="https://github.com/user-attachments/assets/27b47157-5df5-4db1-aa88-1dae799b2bf6"
/>

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2026-01-20 14:43:14 +08:00
committed by GitHub
parent 0664c6e67a
commit 55b20ac63b
4 changed files with 254 additions and 4 deletions

View File

@@ -18,9 +18,10 @@
from typing import Optional, Tuple, Union
import torch
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated
from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
class AscendRMSNorm(RMSNorm):
@@ -95,3 +96,80 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
return x
class LayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd_npu(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
return y.reshape(x_shape_og)
class AscendRMSNormGated(RMSNormGated):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
group_size: Optional[int] = None,
norm_before_gate: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(hidden_size, eps, group_size, norm_before_gate, device, dtype)
self.eps = eps
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
self.norm_before_gate = norm_before_gate
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
def forward_oot(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size,
self.norm_before_gate, True)