First commit
This commit is contained in:
158
pkgs/xformers/ops/triton/rmsnorm_kernels.py
Normal file
158
pkgs/xformers/ops/triton/rmsnorm_kernels.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
if hasattr(tl, "libdevice"):
|
||||
tl_math = tl.libdevice
|
||||
else:
|
||||
tl_math = tl.math
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_kernel(
|
||||
x_ptr,
|
||||
h1_ptr,
|
||||
w_ptr,
|
||||
eps,
|
||||
stride,
|
||||
N_COLS,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
INCLUDE_WEIGHT: tl.constexpr,
|
||||
):
|
||||
row = tl.program_id(0)
|
||||
x_ptr += row * stride
|
||||
h1_ptr += row * stride
|
||||
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
a = tl.load(
|
||||
x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last"
|
||||
).to(tl.float32)
|
||||
_mean += a * a
|
||||
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N_COLS
|
||||
a = tl.load(
|
||||
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).to(tl.float32)
|
||||
if INCLUDE_WEIGHT:
|
||||
w = tl.load(w_ptr + cols, mask=mask)
|
||||
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
|
||||
else:
|
||||
tl.store(h1_ptr + cols, a * rstd, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rms_norm_add_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
h1_ptr,
|
||||
w_ptr,
|
||||
eps,
|
||||
stride,
|
||||
N_COLS,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
INCLUDE_WEIGHT: tl.constexpr,
|
||||
):
|
||||
row = tl.program_id(0)
|
||||
x_ptr += row * stride
|
||||
y_ptr += row * stride
|
||||
h1_ptr += row * stride
|
||||
|
||||
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N_COLS
|
||||
ax = tl.load(
|
||||
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last"
|
||||
).to(tl.float32)
|
||||
ay = tl.load(
|
||||
y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).to(tl.float32)
|
||||
a = ax + ay
|
||||
tl.store(x_ptr + cols, a, mask=mask)
|
||||
_mean += a * a
|
||||
rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps)
|
||||
for offset in range(0, N_COLS, BLOCK_SIZE):
|
||||
cols = offset + tl.arange(0, BLOCK_SIZE)
|
||||
mask = cols < N_COLS
|
||||
a = tl.load(
|
||||
x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first"
|
||||
).to(tl.float32)
|
||||
if INCLUDE_WEIGHT:
|
||||
w = tl.load(w_ptr + cols, mask=mask)
|
||||
tl.store(h1_ptr + cols, a * rstd * w, mask=mask)
|
||||
else:
|
||||
tl.store(h1_ptr + cols, a * rstd, mask=mask)
|
||||
|
||||
|
||||
def _rms_norm_forward(x, attn_norm_weights, eps):
|
||||
if not x.is_contiguous():
|
||||
raise ValueError("data must be contiguous")
|
||||
if attn_norm_weights is not None:
|
||||
if not attn_norm_weights.is_contiguous():
|
||||
raise ValueError("weights must be contiguous")
|
||||
out = torch.empty_like(x)
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
_rms_norm_kernel[(M,)](
|
||||
x_arg,
|
||||
out,
|
||||
attn_norm_weights,
|
||||
eps,
|
||||
x_arg.stride(0),
|
||||
N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
INCLUDE_WEIGHT=attn_norm_weights is not None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _rms_norm_add_forward(x, y, attn_norm_weights, eps):
|
||||
# x, y contiguous of same shape [..., n]
|
||||
# output of same shape, normed over the last dim.
|
||||
if not x.is_contiguous():
|
||||
raise ValueError("x must be contiguous")
|
||||
if not y.is_contiguous():
|
||||
raise ValueError("y must be contiguous")
|
||||
if attn_norm_weights is not None:
|
||||
if not attn_norm_weights.is_contiguous():
|
||||
raise ValueError("weights must be contiguous")
|
||||
out = torch.empty_like(x)
|
||||
x_arg = x.reshape(-1, x.shape[-1])
|
||||
y_arg = y.reshape(-1, x.shape[-1])
|
||||
M, N = x_arg.shape
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
||||
BLOCK_SIZE = max(BLOCK_SIZE, 128)
|
||||
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
|
||||
_rms_norm_add_kernel[(M,)](
|
||||
x_arg,
|
||||
y_arg,
|
||||
out,
|
||||
attn_norm_weights,
|
||||
eps,
|
||||
x_arg.stride(0),
|
||||
N,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
INCLUDE_WEIGHT=attn_norm_weights is not None,
|
||||
)
|
||||
return out
|
||||
Reference in New Issue
Block a user