From 55b20ac63b950d0096c0a84a8aaf9f1e1343e2b4 Mon Sep 17 00:00:00 2001 From: LeeWenquan <83354342+SunnyLee151064@users.noreply.github.com> Date: Tue, 20 Jan 2026 14:43:14 +0800 Subject: [PATCH] [Ops] Add layernorm for qwen3Next (#5765) ### What this PR does / why we need it? Add layernormFn triton op for qwen3Next model for better performance. image ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d --------- Signed-off-by: SunnyLee219 <3294305115@qq.com> --- .github/workflows/nightly_test_a2.yaml | 2 +- vllm_ascend/ops/layernorm.py | 82 ++++++++++- vllm_ascend/ops/triton/layernorm_gated.py | 171 ++++++++++++++++++++++ vllm_ascend/utils.py | 3 +- 4 files changed, 254 insertions(+), 4 deletions(-) create mode 100644 vllm_ascend/ops/triton/layernorm_gated.py diff --git a/.github/workflows/nightly_test_a2.yaml b/.github/workflows/nightly_test_a2.yaml index aac1592c..736fcb27 100644 --- a/.github/workflows/nightly_test_a2.yaml +++ b/.github/workflows/nightly_test_a2.yaml @@ -54,7 +54,7 @@ jobs: tests: tests/e2e/nightly/single_node/models/test_qwen3_8b.py - name: qwen3next os: linux-aarch64-a2-4 - ests: tests/e2e/nightly/single_node/models/test_qwen3_next.py + tests: tests/e2e/nightly/single_node/models/test_qwen3_next.py - name: qwen3-32b os: linux-aarch64-a2-4 tests: tests/e2e/nightly/single_node/models/test_qwen3_32b.py diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 98c0e6bb..5866bc6e 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -18,9 +18,10 @@ from typing import Optional, Tuple, Union import torch +from torch import nn from vllm.config import get_current_vllm_config -from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm - +from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm, RMSNormGated +from vllm_ascend.ops.triton.layernorm_gated import layer_norm_fwd_npu class AscendRMSNorm(RMSNorm): @@ -95,3 +96,80 @@ class AscendGemmaRMSNorm(GemmaRMSNorm): x, _ = torch_npu.npu_rms_norm(x, 1.0 + self.weight, self.variance_epsilon) return x + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward(ctx, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = layer_norm_fwd_npu( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + +class AscendRMSNormGated(RMSNormGated): + + def __init__( + self, + hidden_size, + eps: float = 1e-5, + group_size: Optional[int] = None, + norm_before_gate: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """If group_size is not None, we do GroupNorm with each group having group_size elements. + group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(hidden_size, eps, group_size, norm_before_gate, device, dtype) + self.eps = eps + self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter("bias", None) + self.group_size = group_size + self.norm_before_gate = norm_before_gate + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.ones_(self.weight) + + def forward_oot(self, x, z=None): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z)) + """ + return LayerNormFn.apply(x, self.weight, self.bias, z, self.eps, self.group_size, + self.norm_before_gate, True) \ No newline at end of file diff --git a/vllm_ascend/ops/triton/layernorm_gated.py b/vllm_ascend/ops/triton/layernorm_gated.py new file mode 100644 index 00000000..48ca46f5 --- /dev/null +++ b/vllm_ascend/ops/triton/layernorm_gated.py @@ -0,0 +1,171 @@ +# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py +# Copyright (c) 2024, Tri Dao. +# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate. +# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. +# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. +# mypy: ignore-errors + +import torch +from vllm.triton_utils import tl, triton + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel_npu( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + pid_m = tl.program_id(0) + group = tl.program_id(1) + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + + # Compute row indices for this program + rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.arange(0, BLOCK_N) + + # Mask for valid rows and cols + row_mask = rows < M + col_mask = cols < N + + # Load weight once (broadcasted over rows) + w = tl.load(W + cols, mask=col_mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=col_mask).to(tl.float32) + + # Load X: shape [BLOCK_M, BLOCK_N] + x_ptrs = X + rows[:, None] * stride_x_row + cols[None, :] + group * N + x = tl.load(x_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32) + + # Load Z if needed + if HAS_Z: + z_ptrs = Z + rows[:, None] * stride_z_row + cols[None, :] + group * N + z = tl.load(z_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32) + if not NORM_BEFORE_GATE: + x *= z * tl.sigmoid(z) + + # Compute statistics per row + if not IS_RMS_NORM: + mean = tl.sum(x, axis=1) / N # [BLOCK_M] + xbar = tl.where(col_mask[None, :], x - mean[:, None], 0.0) + var = tl.sum(xbar * xbar, axis=1) / N + tl.store(Mean + rows, mean, mask=row_mask) + else: + xbar = tl.where(col_mask[None, :], x, 0.0) + var = tl.sum(xbar * xbar, axis=1) / N + + rstd = 1.0 / tl.sqrt(var + eps) # [BLOCK_M] + tl.store(Rstd + rows, rstd, mask=row_mask) + + # Normalize + if not IS_RMS_NORM: + x_hat = (x - mean[:, None]) * rstd[:, None] + else: + x_hat = x * rstd[:, None] + + y = x_hat * w[None, :] + if HAS_BIAS: + y += b[None, :] + + # Post-gate + if HAS_Z and NORM_BEFORE_GATE: + y *= z * tl.sigmoid(z) + + # Store output + y_ptrs = Y + rows[:, None] * stride_y_row + cols[None, :] + group * N + tl.store(y_ptrs, y, mask=row_mask[:, None] & col_mask[None, :]) + + +def layer_norm_fwd_npu( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("Feature dim too large.") + + # Choose BLOCK_M: e.g., 16, 32, 64 — depends on NPU vector core capacity + BLOCK_M = 64 # Tune this based on your NPU's register/shared memory + + # Now grid is (num blocks over M, num groups) + grid = (triton.cdiv(M, BLOCK_M), ngroups) + _layer_norm_fwd_1pass_kernel_npu[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + # Remove multibuffer if not needed + ) + return out, mean, rstd \ No newline at end of file diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 7ac39131..86e05841 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -669,7 +669,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE, AscendSharedFusedMoE - from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm + from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm, AscendRMSNormGated from vllm_ascend.ops.linear import ( AscendColumnParallelLinear, AscendMergedColumnParallelLinear, @@ -715,6 +715,7 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "MultiHeadLatentAttentionWrapper": AscendMultiHeadLatentAttention, "MMEncoderAttention": AscendMMEncoderAttention, "ApplyRotaryEmb": AscendApplyRotaryEmb, + "RMSNormGated": AscendRMSNormGated, } # 310P: override selected ops with 310P implementations (keep minimal changes outside _310p)