xc-llm-ascend/vllm_ascend/ops/triton/layernorm_gated.py
Mr.WXS dba34d4915 [v0.18.0][Triton][Qwen3.5] delete expr for kernels args (#7646)
### What this PR does / why we need it?
Some parameters of the Triton kernels are unnecessarily marked with the
`constexpr` qualifier. Whenever one of these parameters takes a new value,
the kernel is recompiled, which significantly degrades model performance.
This change removes the qualifier from those parameters.
backport: https://github.com/vllm-project/vllm-ascend/pull/7482
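
To illustrate the pattern (a hypothetical kernel for illustration only; the actual change is in `_layer_norm_fwd_1pass_kernel_npu` below): dynamic values such as strides, `M`, `N`, and `eps` become plain runtime arguments and are listed in `do_not_specialize`, while only genuine compile-time switches keep `tl.constexpr`:

```python
from vllm.triton_utils import tl, triton

# Before: eps: tl.constexpr -> every new eps value forces a recompile.
# After: eps (and strides, M, N) are plain runtime arguments; listing them
# in do_not_specialize also disables Triton's implicit value-based
# specialization, so changing their values reuses the compiled kernel.
@triton.jit(do_not_specialize=["M", "N", "eps"])
def _my_kernel(X, Y, M, N, eps, BLOCK_N: tl.constexpr):
    ...
```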


Signed-off-by: w30012745 <wangxiaoshuai2@h-partners.com>
Co-authored-by: w30012745 <wangxiaoshuai2@h-partners.com>
2026-03-25 23:31:27 +08:00


# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
# mypy: ignore-errors
import torch
from vllm.triton_utils import tl, triton
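
# NOTE: strides, M, N, and eps are deliberately plain runtime arguments
# (not tl.constexpr) and are listed in do_not_specialize below, so that a
# change in their values does not trigger a Triton recompile.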
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit(do_not_specialize=["stride_x_row", "stride_y_row", "stride_z_row", "M", "N", "eps"])
def _layer_norm_fwd_1pass_kernel_npu(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    pid_m = tl.program_id(0)
    group = tl.program_id(1)
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute row indices for this program
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, BLOCK_N)
    # Mask for valid rows and cols
    row_mask = rows < M
    col_mask = cols < N
    # Load weight once (broadcasted over rows)
    w = tl.load(W + cols, mask=col_mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=col_mask).to(tl.float32)
    # Load X: shape [BLOCK_M, BLOCK_N]
    x_ptrs = X + rows[:, None] * stride_x_row + cols[None, :] + group * N
    x = tl.load(x_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32)
    # Load Z if needed
    if HAS_Z:
        z_ptrs = Z + rows[:, None] * stride_z_row + cols[None, :] + group * N
        z = tl.load(z_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32)
        if not NORM_BEFORE_GATE:
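            # Gate before normalizing: x <- x * silu(z), silu(z) = z * sigmoid(z)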
            x *= z * tl.sigmoid(z)
    # Compute statistics per row
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=1) / N  # [BLOCK_M]
        xbar = tl.where(col_mask[None, :], x - mean[:, None], 0.0)
        var = tl.sum(xbar * xbar, axis=1) / N
        tl.store(Mean + rows, mean, mask=row_mask)
    else:
        xbar = tl.where(col_mask[None, :], x, 0.0)
        var = tl.sum(xbar * xbar, axis=1) / N
    rstd = 1.0 / tl.sqrt(var + eps)  # [BLOCK_M]
    tl.store(Rstd + rows, rstd, mask=row_mask)
    # Normalize
    if not IS_RMS_NORM:
        x_hat = (x - mean[:, None]) * rstd[:, None]
    else:
        x_hat = x * rstd[:, None]
    y = x_hat * w[None, :]
    if HAS_BIAS:
        y += b[None, :]
    # Post-gate
    if HAS_Z and NORM_BEFORE_GATE:
        y *= z * tl.sigmoid(z)
    # Store output
    y_ptrs = Y + rows[:, None] * stride_y_row + cols[None, :] + group * N
    tl.store(y_ptrs, y, mask=row_mask[:, None] & col_mask[None, :])
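

# Host-side wrapper: validates shapes and strides, allocates outputs, and
# launches the kernel on a 2D grid of (row blocks, groups).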
def layer_norm_fwd_npu(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    MAX_FUSED_SIZE = 65536 // x.element_size()
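    # e.g. for fp16/bf16 inputs (2 bytes/element): 65536 // 2 = 32768 columns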
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("Feature dim too large.")
    # Choose BLOCK_M (e.g. 16, 32, 64) depending on NPU vector core capacity
    BLOCK_M = 64  # Tune this based on your NPU's register/shared memory
    # Grid is (number of row blocks over M, number of groups)
    grid = (triton.cdiv(M, BLOCK_M), ngroups)
    _layer_norm_fwd_1pass_kernel_npu[grid](
        x,
        out,
        weight,
        bias,
        z,
        mean,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0) if z is not None else 0,
        M,
        group_size,
        eps,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
        # NPU multibuffer launch option omitted; re-add here if needed
    )
    return out, mean, rstd
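

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the upstream
    # module). Assumes a Triton-capable device; the "npu" device string is
    # an assumption for Ascend, swap in "cuda" or similar for your setup.
    import torch.nn.functional as F

    device = "npu"
    x = torch.randn(8, 256, device=device, dtype=torch.float16)
    w = torch.ones(256, device=device, dtype=torch.float16)
    z = torch.randn_like(x)
    y, mean, rstd = layer_norm_fwd_npu(x, w, None, 1e-5, z=z, norm_before_gate=True)
    # fp32 reference: LayerNorm followed by the SiLU gate (norm_before_gate=True)
    ref = F.layer_norm(x.float(), (256,), weight=w.float(), eps=1e-5)
    ref = ref * (z.float() * torch.sigmoid(z.float()))
    assert torch.allclose(y.float(), ref, atol=1e-2, rtol=1e-2)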