Define layer_norm_fwd_kernel locally in vllm_ascend/ops/fla.py and cap the
launch grid.

The kernel is defined in this file instead of being imported from vLLM. The
first grid dimension is capped at MAX_CORES, and each program loops over rows
inside the kernel when M exceeds that cap.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. The
indentation of the context lines in the second and third hunks is
reconstructed; confirm it against fla.py or apply with
`git apply --ignore-whitespace`. The `index` line still carries the blob ids
of the original patch.

diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py
index b200c67..7903900 100644
--- a/vllm_ascend/ops/fla.py
+++ b/vllm_ascend/ops/fla.py
@@ -8,9 +8,104 @@
 
 import torch
 import torch.nn.functional as F
-import triton
-from vllm.model_executor.layers.fla.ops.layernorm_guard import \
-    layer_norm_fwd_kernel
+from vllm.triton_utils import tl, triton
+
+# Cap on the first launch-grid dimension (rows). When M is larger, each
+# program loops over several rows inside the kernel instead.
+MAX_CORES = 65535
+
+
+@triton.heuristics({
+    "HAS_BIAS": lambda args: args["B"] is not None,
+    "HAS_Z": lambda args: args["Z"] is not None,
+})
+@triton.jit
+def layer_norm_fwd_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    Z,  # pointer to the other branch
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_z_row,
+    M,  # number of rows in X_base
+    N,  # number of columns in X_base
+    eps,  # epsilon to avoid division by zero
+    BLOCK_N: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    HAS_Z: tl.constexpr,
+    NORM_BEFORE_GATE: tl.constexpr,
+    IS_RMS_NORM: tl.constexpr,
+    N_CORES: tl.constexpr,
+):
+    # Forward LayerNorm (or RMSNorm when IS_RMS_NORM) of N columns per row
+    # and group, with an optional z * sigmoid(z) gate applied either before
+    # or after the normalization depending on NORM_BEFORE_GATE.
+    #
+    # Map the program id to the row of X_base and Y_base it should compute.
+    row = tl.program_id(0)
+    group = tl.program_id(1)
+
+    # Rows are covered in sweeps of BLOCK_ROWS = min(M, N_CORES) rows; this
+    # program handles rows row, row + BLOCK_ROWS, row + 2 * BLOCK_ROWS, ...
+    BLOCK_ROWS = M if M < N_CORES else N_CORES
+    n_iters = M // BLOCK_ROWS
+    remain = M % BLOCK_ROWS
+    # The final partial sweep has only `remain` rows, so just the programs
+    # with row < remain run one extra iteration.
+    if row < remain:
+        n_iters = n_iters + 1
+
+    for i in tl.range(n_iters):
+        # Base pointers for sweep i. X/Y/Z already include this program's
+        # row offset; Mean/Rstd get `row` added at their store sites.
+        X_base = X + (i * BLOCK_ROWS *
+                      stride_x_row) + row * stride_x_row + group * N
+        Y_base = Y + (i * BLOCK_ROWS *
+                      stride_y_row) + row * stride_y_row + group * N
+        if HAS_Z:
+            Z_base = Z + (i * BLOCK_ROWS *
+                          stride_z_row) + row * stride_z_row + group * N
+        if not IS_RMS_NORM:
+            Mean_base = Mean + (i * BLOCK_ROWS) + group * M
+        Rstd_base = Rstd + (i * BLOCK_ROWS) + group * M
+        W_base = W + group * N
+        if HAS_BIAS:
+            B_base = B + group * N
+        # Compute mean and variance
+        cols = tl.arange(0, BLOCK_N)
+        x = tl.load(X_base + cols, mask=cols < N, other=0.).to(tl.float32)
+        if HAS_Z and not NORM_BEFORE_GATE:
+            # Gate first: normalize x * z * sigmoid(z).
+            z = tl.load(Z_base + cols, mask=cols < N).to(tl.float32)
+            x *= z * tl.sigmoid(z)
+        if not IS_RMS_NORM:
+            mean = tl.sum(x, axis=0) / N
+            tl.store(Mean_base + row, mean)
+            xbar = tl.where(cols < N, x - mean, 0.)
+            var = tl.sum(xbar * xbar, axis=0) / N
+        else:
+            # RMSNorm: no mean subtraction, var is the mean of squares.
+            xbar = tl.where(cols < N, x, 0.)
+            var = tl.sum(xbar * xbar, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        tl.store(Rstd_base + row, rstd)
+        # Normalize and apply linear transformation
+        mask = cols < N
+        w = tl.load(W_base + cols, mask=mask).to(tl.float32)
+        if HAS_BIAS:
+            b = tl.load(B_base + cols, mask=mask).to(tl.float32)
+        x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+        y = x_hat * w + b if HAS_BIAS else x_hat * w
+        if HAS_Z and NORM_BEFORE_GATE:
+            # Gate last: multiply the normalized output by z * sigmoid(z).
+            z = tl.load(Z_base + cols, mask=mask).to(tl.float32)
+            y *= z * tl.sigmoid(z)
+        # Write output
+        tl.store(Y_base + cols, y, mask=mask)
 
 
 def _layer_norm_fwd(
@@ -55,7 +150,7 @@ def _layer_norm_fwd(
             "This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
     num_warps = min(max(BLOCK_N // 256, 1), 8)
-    grid = (M, ngroups)
+    grid = (min(M, MAX_CORES), ngroups)
     with torch.npu.device(x.device.index):
         layer_norm_fwd_kernel[grid](
             x,
@@ -74,6 +169,7 @@ def _layer_norm_fwd(
             BLOCK_N=BLOCK_N,
             NORM_BEFORE_GATE=norm_before_gate,
             IS_RMS_NORM=is_rms_norm,
+            N_CORES=MAX_CORES,
             num_warps=num_warps,
         )
     return out, mean, rstd