First commit
This commit is contained in:
179
pkgs/xformers/triton/k_layer_norm.py
Normal file
179
pkgs/xformers/triton/k_layer_norm.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# CREDITS: This comes almost as-is from the Triton layer norm tutorial
|
||||
# https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py
|
||||
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# fmt: off
|
||||
@triton.jit
|
||||
def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
|
||||
# fmt: on
|
||||
"""
|
||||
Fused layernorm kernel over a 3d tensor.
|
||||
The layer norm is applied over the last dimension.
|
||||
|
||||
Compute
|
||||
y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta
|
||||
"""
|
||||
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < N
|
||||
|
||||
# Move to this row
|
||||
x_ptrs = X + row * stride + cols
|
||||
x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32)
|
||||
|
||||
# Compute mean and variance
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
x_zm = tl.where(mask, x - mean, 0.0)
|
||||
tl.store(M + row, mean)
|
||||
|
||||
x_var = tl.sum(x_zm * x_zm, axis=0) / N
|
||||
rstd = 1.0 / tl.sqrt(x_var + eps)
|
||||
|
||||
# Normalize, optionally affine
|
||||
y = x_zm * rstd
|
||||
tl.store(V + row, rstd)
|
||||
|
||||
mask = cols < N
|
||||
if affine:
|
||||
w = tl.load(W + cols, mask=mask, other=1.0)
|
||||
b = tl.load(B + cols, mask=mask, other=0.0)
|
||||
y = y * w + b
|
||||
|
||||
y_ptrs = Y + row * stride + cols
|
||||
tl.store(y_ptrs, y, mask=mask)
|
||||
|
||||
|
||||
# Backward pass (DX + partial DW + partial DB)
|
||||
# fmt: off
|
||||
@triton.jit
|
||||
def layer_norm_bwd_dx_fused(
|
||||
DX, DY, DW, DB,
|
||||
X, W, M, V,
|
||||
Lock, stride, N,
|
||||
# META-parameters
|
||||
affine: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
):
|
||||
# fmt: on
|
||||
|
||||
# position of elements processed by this program
|
||||
row = tl.program_id(0)
|
||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < N
|
||||
|
||||
# offset data pointers to start at the row of interest
|
||||
x_ptrs = X + row * stride + cols
|
||||
dy_ptrs = DY + row * stride + cols
|
||||
|
||||
# load data to SRAM
|
||||
x = tl.load(x_ptrs, mask=mask, other=0)
|
||||
dy = tl.load(dy_ptrs, mask=mask, other=0)
|
||||
mean = tl.load(M + row)
|
||||
rstd = tl.load(V + row)
|
||||
|
||||
# compute dx
|
||||
xhat = (x - mean) * rstd
|
||||
|
||||
if affine:
|
||||
w = tl.load(W + cols, mask=mask, other=0)
|
||||
wdy = w * dy
|
||||
else:
|
||||
wdy = dy
|
||||
|
||||
xhat = tl.where(mask, xhat, 0.)
|
||||
wdy = tl.where(mask, wdy, 0.)
|
||||
mean1 = tl.sum(xhat * wdy, axis=0) / N
|
||||
mean2 = tl.sum(wdy, axis=0) / N
|
||||
dx = (wdy - (xhat * mean1 + mean2)) * rstd
|
||||
|
||||
# write-back dx
|
||||
cols = tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = cols < N # re-materialize the mask to save registers
|
||||
dx_ptrs = DX + row * stride + cols
|
||||
tl.store(dx_ptrs, dx, mask=mask)
|
||||
|
||||
if affine:
|
||||
# accumulate partial sums for dw/db
|
||||
partial_dw = (dy * xhat).to(w.dtype)
|
||||
partial_db = dy.to(w.dtype)
|
||||
|
||||
# offset locks and weight/bias gradient pointer
|
||||
# each kernel instance accumulates partial sums for
|
||||
# DW and DB into one of GROUP_SIZE_M independent buffers
|
||||
# these buffers stay in the L2, which allow this kernel
|
||||
# to be fast
|
||||
lock_id = row % GROUP_SIZE_M
|
||||
Lock += lock_id
|
||||
Count = Lock + GROUP_SIZE_M
|
||||
|
||||
# - wait for a lock on the accumulated dw/db
|
||||
while tl.atomic_cas(Lock, 0, 1) == 1:
|
||||
pass
|
||||
count = tl.load(Count)
|
||||
|
||||
# - we got the lock, accumulate this kernel's results with
|
||||
# the stored values.
|
||||
dw_ptrs = DW + lock_id * N + cols
|
||||
db_ptrs = DB + lock_id * N + cols
|
||||
|
||||
if count == 0:
|
||||
# first store doesn't accumulate
|
||||
tl.atomic_xchg(Count, 1)
|
||||
else:
|
||||
partial_dw += tl.load(dw_ptrs, mask=mask, other=0.)
|
||||
partial_db += tl.load(db_ptrs, mask=mask, other=0.)
|
||||
|
||||
tl.store(dw_ptrs, partial_dw, mask=mask)
|
||||
tl.store(db_ptrs, partial_db, mask=mask)
|
||||
|
||||
# release lock
|
||||
tl.atomic_xchg(Lock, 0)
|
||||
|
||||
|
||||
# Backward pass (total DW + total DB)
|
||||
# fmt: off
|
||||
@triton.jit
|
||||
def layer_norm_bwd_dwdb(
|
||||
DW, DB, FINAL_DW, FINAL_DB,
|
||||
M, N,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr
|
||||
):
|
||||
# fmt: on
|
||||
|
||||
pid = tl.program_id(0)
|
||||
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
mask_cols = cols < N
|
||||
|
||||
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for i in range(0, M, BLOCK_SIZE_M):
|
||||
rows = i + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs = rows[:, None] * N + cols[None, :]
|
||||
mask_rm = rows < M
|
||||
|
||||
dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
|
||||
db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
|
||||
|
||||
sum_dw = tl.sum(dw, axis=0)
|
||||
sum_db = tl.sum(db, axis=0)
|
||||
|
||||
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
mask_cols = cols < N
|
||||
|
||||
tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)
|
||||
tl.store(FINAL_DB + cols, sum_db, mask=mask_cols)
|
||||
Reference in New Issue
Block a user