First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
from .matmul import matmul
from .softmax import softmax
__all__ = [
"matmul",
"softmax",
]

View File

@@ -0,0 +1,437 @@
import torch
import triton
import triton.language as tl
# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)
# This operation uses super-blocking to make sure that
# it's done efficiently when small blocks can be grouped
# together
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
block_id = tl.program_id(0) + grid_offset
lut += block_id * 3
# offsets
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
# initialize pointers to A
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A \
+ off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B \
+ off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
acc += tl.dot(a, b, out_dtype=tl.float32)
a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk
c = acc.to(C.dtype.element_ty)
# ---------------- #
# Epilogue #
# ---------------- #
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C \
+ off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# (A * B)^T = B^T * A^T
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
# shape constraints
a_dim = -2 if trans_a else -1
b_dim = -1 if trans_b else -2
Ka, Kb = a.shape[a_dim], b.shape[b_dim]
if Ka != Kb:
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
# allocate output
if out is None:
c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
else:
assert out.shape == (a.shape[0], lut.shape[0], block, block)
c = out
grid = [c.shape[1], 1, c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
)
return c
def sdd_lut(layout, block, device):
lut = layout.nonzero(as_tuple=False).to(device).int()
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa)
b = tl.load(pb)
acc += tl.dot(a, b, out_dtype=tl.float32)
pa += inc_a
pb += inc_b * stride_bk
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C \
+ off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
BS1 = b.size(1)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# allocate output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
# meta-parameter heuristics
TILE_N = 128
# compute output
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
# exit()
return c
def dsd_lut(layout, block, step, trans, device):
"""
Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
Example (BLOCK=32, STEP=16)
[[1, 0, 0, 1, 0],
[0, 1, 1, 0, 1],
[1, 0, 1, 0, 0]]
Then the offsets for A are
[0 , 16, 32, 48] <- row 0
\\----/ \\----/
col=0 col=3
[64, 80, 96, 112, 128, 144] <- row 1
\\----/ \\----/ \\------/
col=1 col=2 col=3
[160, 176, 192, 208]
which leads to increments table
[0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
Because B is dense, the offsets are
[0, 16, 96, 112] <- row 0
[32, 48, 64, 80] <- row 1
[0, 16, 64, 80] <- row 2
"""
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes * step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
# -------------------------------
# dense input pointer increments
# -------------------------------
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
# that is smaller than the block size, so we need to do a bit of extra work
# to handle this case
B_idx = nnz[:, 2] * block
B_incs = B_idx.clone()
B_incs[1:] -= B_idx[:-1]
div = block // step
B_incs = B_incs.view(-1, 1).repeat(1, div)
B_incs[:, 1:] = step
B_incs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
B_incs = B_incs.view(-1)
# -------------------------------
# sparse input pointer increments
# -------------------------------
# same as above, except that the increments are in the sparse memory layout
if trans:
A_idx = torch.arange(num_blocks, device=layout.device)
else:
A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone().long()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
A_incs = A_idx * block * block
A_incs[1:] -= A_idx[:-1] * block * block
A_incs = A_incs.view(-1, 1).repeat(1, div)
if trans:
A_incs[:, 1:] = step
A_incs[:, 0] -= (div - 1) * step
else:
A_incs[:, 1:] = step * block
A_incs[:, 0] -= (div - 1) * step * block
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
A_incs = A_incs.view(-1)
# create header
width = col_id.size(0)
offsets = offsets * 2 * div + 4 * width
segments = segments * div
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
# pad by a factor 2*MAX_NUM_STAGES
# to accommodate pre-fetching inside the kernel
pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
incs = torch.cat((incs, pad))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
# AB = (B^T A^T)^T
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.db_lut = db_lut
ctx.db_width = db_width
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
ctx.trans_c = trans_c
ctx.has_out = out is not None
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
da, db = None, None
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
dout = dc if ctx.has_out else None
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, dout
class matmul:
def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
self.block = block
self.mode = mode
self.trans_a = trans_a
self.trans_b = trans_b
self.trans_c = trans_c
self.layout = layout
self.spdims = layout.shape
step = min(block, 32)
if self.mode == 'sdd':
self.c_lut, self.c_width = sdd_lut(layout, block, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
if self.mode == 'dsd':
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
return c

View File

@@ -0,0 +1,239 @@
import torch
import triton
import triton.language as tl
def num_warps(n):
if n <= 128:
return 1
if n <= 256:
return 2
if n <= 512:
return 4
if n <= 4096:
return 8
return 16
@triton.jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# pointer offset
off_a = z * stride_xz
off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx
off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load X
mask = block_n < size
a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
a = a.to(tl.float32)
# compute
out = a
out *= scale
# apply relative attention
if R is not None:
R += z * stride_zr
R += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent)
rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
out += rel_logits
out = out.to(tl.float32)
# apply causal mask
out = tl.where((ns > m) & is_causal, -float("inf"), out)
# computation
out = tl.softmax(out)
# write-back
tl.store(Out + off_a + lane_n, out, mask=mask)
@triton.jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# row-col offset
off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
mask = block_n < size
# pointers
As = Out + z * stride_zout + off_mn
DOuts = DOut + z * stride_zdout + off_mn
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load data
a = tl.load(As + lane_n, mask=mask, other=0.0)
a = a.to(tl.float32)
dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
dout = dout.to(tl.float32)
# compute
a = tl.where((ns > m) & is_causal & (a == a), 0., a)
da = a * (dout - tl.sum(a * dout, 0))
# apply relative attention
if DR is not None:
DR += z * stride_zr
DR += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
da = da * scale
# convert da
# write-back
DAs = DA + z * stride_zdx + off_mn
tl.store(DAs + lane_n, da, mask=mask)
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
sizes = _empty.clone()
# sizes along rows
for h in range(layout.shape[0]):
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
total_sizes = sizes * block
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# block indices
columns = layout.nonzero(as_tuple=False)[:, 2]
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, columns)).type(torch.int32).to(device)
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
M = a.shape[0]
grid = [spdims[0], spdims[1] * block, M]
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=triton.next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
)
# save to context
# ctx.mark_dirty(x)
ctx.save_for_backward(out, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.rel_shape = rel_shape
ctx.rel_strides = rel_strides
ctx.rel_dtype = a.dtype
ctx.is_dense = is_dense
ctx.is_causal = is_causal
return out
@staticmethod
def backward(ctx, dout):
# retrieve from context
out, lut = ctx.saved_tensors
# relative logits gradients
dr = None
if ctx.needs_input_grad[3]:
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
# run kernel
M = out.shape[0]
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
self.block = block
self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
self.is_dense = is_dense
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError(f"relative position embedding must be {a.dtype}")
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
return a