First commit
This commit is contained in:
239
pkgs/triton/ops/blocksparse/softmax.py
Normal file
239
pkgs/triton/ops/blocksparse/softmax.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(n):
|
||||
if n <= 128:
|
||||
return 1
|
||||
if n <= 256:
|
||||
return 2
|
||||
if n <= 512:
|
||||
return 4
|
||||
if n <= 4096:
|
||||
return 8
|
||||
return 16
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _blocksparse_softmax_fwd(
|
||||
Out, A, stride_xz, LUT,
|
||||
R, extent, stride_zr, stride_hr, # relative attention
|
||||
scale, is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
# create index ranges
|
||||
hm = h * tl.num_programs(1) + m
|
||||
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
|
||||
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
|
||||
# extract information from LUT
|
||||
header = LUT + (hm // BLOCK_SIZE) * 2
|
||||
size = tl.load(header + 0)
|
||||
offset = tl.load(header + 1)
|
||||
# pointer offset
|
||||
off_a = z * stride_xz
|
||||
off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx
|
||||
off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx
|
||||
# do not need to read column indices in the dense case
|
||||
if IS_DENSE:
|
||||
ns = tl.arange(0, ROW_SIZE)
|
||||
else:
|
||||
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
|
||||
start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
|
||||
ns = start_n * BLOCK_SIZE + lane_n
|
||||
# load X
|
||||
mask = block_n < size
|
||||
a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
|
||||
a = a.to(tl.float32)
|
||||
# compute
|
||||
out = a
|
||||
out *= scale
|
||||
# apply relative attention
|
||||
if R is not None:
|
||||
R += z * stride_zr
|
||||
R += h * stride_hr
|
||||
off_lo = (extent - m - 1) + ns
|
||||
mask_lo = (off_lo >= 0) & (off_lo < extent)
|
||||
rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
|
||||
out += rel_logits
|
||||
out = out.to(tl.float32)
|
||||
# apply causal mask
|
||||
out = tl.where((ns > m) & is_causal, -float("inf"), out)
|
||||
# computation
|
||||
out = tl.softmax(out)
|
||||
# write-back
|
||||
tl.store(Out + off_a + lane_n, out, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _blocksparse_softmax_bwd(
|
||||
DA, stride_zdx,
|
||||
DOut, stride_zdout,
|
||||
Out, stride_zout,
|
||||
scale,
|
||||
LUT,
|
||||
DR, extent, stride_zr, stride_hr, stride_er,
|
||||
is_causal,
|
||||
ROW_SIZE: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
IS_DENSE: tl.constexpr,
|
||||
):
|
||||
h = tl.program_id(0)
|
||||
m = tl.program_id(1)
|
||||
z = tl.program_id(2)
|
||||
# create index ranges
|
||||
hm = h * tl.num_programs(1) + m
|
||||
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
|
||||
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
|
||||
# extract information from LUT
|
||||
header = LUT + (hm // BLOCK_SIZE) * 2
|
||||
size = tl.load(header + 0)
|
||||
offset = tl.load(header + 1)
|
||||
# row-col offset
|
||||
off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
|
||||
off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
|
||||
mask = block_n < size
|
||||
# pointers
|
||||
As = Out + z * stride_zout + off_mn
|
||||
DOuts = DOut + z * stride_zdout + off_mn
|
||||
# do not need to read column indices in the dense case
|
||||
if IS_DENSE:
|
||||
ns = tl.arange(0, ROW_SIZE)
|
||||
else:
|
||||
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
|
||||
start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
|
||||
ns = start_n * BLOCK_SIZE + lane_n
|
||||
# load data
|
||||
a = tl.load(As + lane_n, mask=mask, other=0.0)
|
||||
a = a.to(tl.float32)
|
||||
dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
|
||||
dout = dout.to(tl.float32)
|
||||
# compute
|
||||
a = tl.where((ns > m) & is_causal & (a == a), 0., a)
|
||||
da = a * (dout - tl.sum(a * dout, 0))
|
||||
# apply relative attention
|
||||
if DR is not None:
|
||||
DR += z * stride_zr
|
||||
DR += h * stride_hr
|
||||
off_lo = (extent - m - 1) + ns
|
||||
mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
|
||||
tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
|
||||
da = da * scale
|
||||
# convert da
|
||||
# write-back
|
||||
DAs = DA + z * stride_zdx + off_mn
|
||||
tl.store(DAs + lane_n, da, mask=mask)
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
sizes = _empty.clone()
|
||||
# sizes along rows
|
||||
for h in range(layout.shape[0]):
|
||||
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
|
||||
total_sizes = sizes * block
|
||||
# offsets in block format
|
||||
offsets = torch.zeros_like(sizes)
|
||||
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
|
||||
# block indices
|
||||
columns = layout.nonzero(as_tuple=False)[:, 2]
|
||||
header = torch.stack((sizes, offsets), dim=1).view(-1)
|
||||
lut = torch.cat((header, columns)).type(torch.int32).to(device)
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
M = a.shape[0]
|
||||
grid = [spdims[0], spdims[1] * block, M]
|
||||
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
|
||||
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=triton.next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
)
|
||||
# save to context
|
||||
# ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(out, lut)
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.maxlut = maxlut
|
||||
ctx.scale = scale
|
||||
ctx.rel_shape = rel_shape
|
||||
ctx.rel_strides = rel_strides
|
||||
ctx.rel_dtype = a.dtype
|
||||
ctx.is_dense = is_dense
|
||||
ctx.is_causal = is_causal
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
# retrieve from context
|
||||
out, lut = ctx.saved_tensors
|
||||
# relative logits gradients
|
||||
dr = None
|
||||
if ctx.needs_input_grad[3]:
|
||||
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
|
||||
# run kernel
|
||||
M = out.shape[0]
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
|
||||
|
||||
class softmax:
|
||||
def __init__(self, layout, block, device, is_dense=False):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
self.block = block
|
||||
self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
|
||||
self.is_dense = is_dense
|
||||
|
||||
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
|
||||
if rel_logits is not None and rel_logits.dtype != a.dtype:
|
||||
raise ValueError(f"relative position embedding must be {a.dtype}")
|
||||
a = _softmax.apply(
|
||||
a, scale, rel_logits, is_causal,
|
||||
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
|
||||
)
|
||||
return a
|
||||
Reference in New Issue
Block a user