[Ops] Add layernorm for qwen3Next (#5765)

### What this PR does / why we need it?
Add layernormFn triton op for qwen3Next model for better performance.

<img width="248" height="526" alt="image"
src="https://github.com/user-attachments/assets/27b47157-5df5-4db1-aa88-1dae799b2bf6"
/>

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

---------

Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
LeeWenquan
2026-01-20 14:43:14 +08:00
committed by GitHub
parent 0664c6e67a
commit 55b20ac63b
4 changed files with 254 additions and 4 deletions

View File

@@ -54,7 +54,7 @@ jobs:
tests: tests/e2e/nightly/single_node/models/test_qwen3_8b.py
- name: qwen3next
os: linux-aarch64-a2-4
ests: tests/e2e/nightly/single_node/models/test_qwen3_next.py
tests: tests/e2e/nightly/single_node/models/test_qwen3_next.py
- name: qwen3-32b
os: linux-aarch64-a2-4
tests: tests/e2e/nightly/single_node/models/test_qwen3_32b.py

View File

@@ -18,9 +18,10 @@
from typing import Optional, Tuple, Union
import torch
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated
from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu
class AscendRMSNorm(RMSNorm):
@@ -95,3 +96,80 @@ class AscendGemmaRMSNorm(GemmaRMSNorm):
x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight,
self.variance_epsilon)
return x
class LayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd_npu(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
return y.reshape(x_shape_og)
class AscendRMSNormGated(RMSNormGated):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
group_size: Optional[int] = None,
norm_before_gate: bool = False,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(hidden_size, eps, group_size, norm_before_gate, device, dtype)
self.eps = eps
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
self.norm_before_gate = norm_before_gate
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
def forward_oot(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size,
self.norm_before_gate, True)

View File

@@ -0,0 +1,171 @@
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
# mypy: ignore-errors
import torch
from vllm.triton_utils import tl, triton
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel_npu(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the other branch
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_z_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
pid_m = tl.program_id(0)
group = tl.program_id(1)
if not IS_RMS_NORM:
Mean += group * M
Rstd += group * M
W += group * N
if HAS_BIAS:
B += group * N
# Compute row indices for this program
rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
cols = tl.arange(0, BLOCK_N)
# Mask for valid rows and cols
row_mask = rows < M
col_mask = cols < N
# Load weight once (broadcasted over rows)
w = tl.load(W + cols, mask=col_mask).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + cols, mask=col_mask).to(tl.float32)
# Load X: shape [BLOCK_M, BLOCK_N]
x_ptrs = X + rows[:, None] * stride_x_row + cols[None, :] + group * N
x = tl.load(x_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32)
# Load Z if needed
if HAS_Z:
z_ptrs = Z + rows[:, None] * stride_z_row + cols[None, :] + group * N
z = tl.load(z_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32)
if not NORM_BEFORE_GATE:
x *= z * tl.sigmoid(z)
# Compute statistics per row
if not IS_RMS_NORM:
mean = tl.sum(x, axis=1) / N # [BLOCK_M]
xbar = tl.where(col_mask[None, :], x - mean[:, None], 0.0)
var = tl.sum(xbar * xbar, axis=1) / N
tl.store(Mean + rows, mean, mask=row_mask)
else:
xbar = tl.where(col_mask[None, :], x, 0.0)
var = tl.sum(xbar * xbar, axis=1) / N
rstd = 1.0 / tl.sqrt(var + eps) # [BLOCK_M]
tl.store(Rstd + rows, rstd, mask=row_mask)
# Normalize
if not IS_RMS_NORM:
x_hat = (x - mean[:, None]) * rstd[:, None]
else:
x_hat = x * rstd[:, None]
y = x_hat * w[None, :]
if HAS_BIAS:
y += b[None, :]
# Post-gate
if HAS_Z and NORM_BEFORE_GATE:
y *= z * tl.sigmoid(z)
# Store output
y_ptrs = Y + rows[:, None] * stride_y_row + cols[None, :] + group * N
tl.store(y_ptrs, y, mask=row_mask[:, None] & col_mask[None, :])
def layer_norm_fwd_npu(
x,
weight,
bias,
eps,
z=None,
out=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError("Feature dim too large.")
# Choose BLOCK_M: e.g., 16, 32, 64 — depends on NPU vector core capacity
BLOCK_M = 64 # Tune this based on your NPU's register/shared memory
# Now grid is (num blocks over M, num groups)
grid = (triton.cdiv(M, BLOCK_M), ngroups)
_layer_norm_fwd_1pass_kernel_npu[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
# Remove multibuffer if not needed
)
return out, mean, rstd

View File

@@ -669,7 +669,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm, AscendRMSNormGated
from vllm_ascend.ops.linear import (
AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
@@ -715,6 +715,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
"MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention,
"MMEncoderAttention": AscendMMEncoderAttention,
"ApplyRotaryEmb": AscendApplyRotaryEmb,
"RMSNormGated": AscendRMSNormGated,
}
# 310P: override selected ops with 310P implementations (keep minimal changes outside _310p)