# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but beyond that it's much slower due to register spilling.
# The models we train have hidden dims up to 8k anyway (e.g. Llama 70B), so this is fine.

# mypy: ignore-errors

import torch

from vllm.triton_utils import tl, triton


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel_npu(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X (i.e. the group size)
    eps,  # epsilon to avoid division by zero
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program ids to the rows of X and Y and the group to compute.
    pid_m = tl.program_id(0)
    group = tl.program_id(1)
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N

    # Compute row indices for this program
    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, BLOCK_N)
    # Masks for valid rows and cols
    row_mask = rows < M
    col_mask = cols < N

    # Load weight (and bias) once; broadcast over rows
    w = tl.load(W + cols, mask=col_mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=col_mask).to(tl.float32)

    # Load X: shape [BLOCK_M, BLOCK_N]
    x_ptrs = X + rows[:, None] * stride_x_row + cols[None, :] + group * N
    x = tl.load(x_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32)

    # Load Z if needed
    if HAS_Z:
        z_ptrs = Z + rows[:, None] * stride_z_row + cols[None, :] + group * N
        z = tl.load(z_ptrs, mask=row_mask[:, None] & col_mask[None, :]).to(tl.float32)
        if not NORM_BEFORE_GATE:
            # Pre-gate: apply SiLU(z) to the input before normalization
            x *= z * tl.sigmoid(z)

    # Compute statistics per row
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=1) / N  # [BLOCK_M]
        xbar = tl.where(col_mask[None, :], x - mean[:, None], 0.0)
        var = tl.sum(xbar * xbar, axis=1) / N
        tl.store(Mean + rows, mean, mask=row_mask)
    else:
        xbar = tl.where(col_mask[None, :], x, 0.0)
        var = tl.sum(xbar * xbar, axis=1) / N
    rstd = 1.0 / tl.sqrt(var + eps)  # [BLOCK_M]
    tl.store(Rstd + rows, rstd, mask=row_mask)

    # Normalize
    if not IS_RMS_NORM:
        x_hat = (x - mean[:, None]) * rstd[:, None]
    else:
        x_hat = x * rstd[:, None]
    y = x_hat * w[None, :]
    if HAS_BIAS:
        y += b[None, :]

    # Post-gate: apply SiLU(z) to the normalized output
    if HAS_Z and NORM_BEFORE_GATE:
        y *= z * tl.sigmoid(z)

    # Store output
    y_ptrs = Y + rows[:, None] * stride_y_row + cols[None, :] + group * N
    tl.store(y_ptrs, y, mask=row_mask[:, None] & col_mask[None, :])
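
# A plain-PyTorch reference of the same computation may help when porting or
# testing the kernel above. The sketch below is ours, not part of the upstream
# fla-org repo: the name `layer_norm_gated_ref` and its exact signature are
# assumptions. Its semantics mirror the kernel: per-group statistics, with an
# optional SiLU gate applied to the input when norm_before_gate=False, or to
# the normalized output when norm_before_gate=True.
def layer_norm_gated_ref(x, weight, bias, eps, z=None, group_size=None,
                         norm_before_gate=True, is_rms_norm=False):
    dtype = x.dtype
    x = x.float()
    if group_size is None:
        group_size = x.shape[-1]
    if z is not None and not norm_before_gate:
        # Pre-gate the input before computing statistics
        x = x * torch.nn.functional.silu(z.float())
    # Per-group statistics: fold the feature dim into (ngroups, group_size)
    xg = x.reshape(*x.shape[:-1], -1, group_size)
    if is_rms_norm:
        rstd = torch.rsqrt(xg.pow(2).mean(dim=-1, keepdim=True) + eps)
        x_hat = (xg * rstd).reshape_as(x)
    else:
        m = xg.mean(dim=-1, keepdim=True)
        var = (xg - m).pow(2).mean(dim=-1, keepdim=True)  # biased, like the kernel
        x_hat = ((xg - m) * torch.rsqrt(var + eps)).reshape_as(x)
    y = x_hat * weight.float()
    if bias is not None:
        y = y + bias.float()
    if z is not None and norm_before_gate:
        # Post-gate the normalized output
        y = y * torch.nn.functional.silu(z.float())
    return y.to(dtype)
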
def layer_norm_fwd_npu(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # Allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
            if not is_rms_norm else None)
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Cap the block width at 64 KB worth of elements
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("Feature dim too large.")
    # Choose BLOCK_M (e.g. 16, 32, or 64) based on the NPU's vector core
    # capacity; tune it for your hardware's register/shared-memory budget.
    BLOCK_M = 64
    # The grid is (number of blocks over M, number of groups)
    grid = (triton.cdiv(M, BLOCK_M), ngroups)
    _layer_norm_fwd_1pass_kernel_npu[grid](
        x,
        out,
        weight,
        bias,
        z,
        mean,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0) if z is not None else 0,
        M,
        group_size,
        eps,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
        # Remove multibuffer if not needed
    )
    return out, mean, rstd
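

if __name__ == "__main__":
    # Minimal smoke test against the reference above (a sketch, not part of
    # the original module). It assumes a Triton-capable device; the "cuda"
    # string is a placeholder, so substitute your NPU device string.
    device = "cuda"
    torch.manual_seed(0)
    x = torch.randn(512, 2048, device=device, dtype=torch.float16)
    w = torch.randn(2048, device=device, dtype=torch.float32)
    z = torch.randn_like(x)
    y, _, rstd = layer_norm_fwd_npu(x, w, bias=None, eps=1e-5, z=z,
                                    group_size=512, is_rms_norm=True)
    y_ref = layer_norm_gated_ref(x, w, None, 1e-5, z=z, group_size=512,
                                 is_rms_norm=True)
    # fp16 in/out with fp32 accumulation: allow loose tolerances
    torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)
    print("ok:", y.shape, rstd.shape)  # rstd is float32, shape (ngroups * M,)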