First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
# from .conv import _conv, conv
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .flash_attention import attention
from .matmul import _matmul, matmul
from .bmm_matmul import _bmm, bmm
__all__ = [
"blocksparse",
"_cross_entropy",
"cross_entropy",
"_matmul",
"matmul",
"_bmm",
"bmm",
"attention",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,7 @@
from .matmul import matmul
from .softmax import softmax
__all__ = [
"matmul",
"softmax",
]

View File

@@ -0,0 +1,437 @@
import torch
import triton
import triton.language as tl
# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)
# This operation uses super-blocking to make sure that
# it's done efficiently when small blocks can be grouped
# together
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
block_id = tl.program_id(0) + grid_offset
lut += block_id * 3
# offsets
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
# initialize pointers to A
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A \
+ off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B \
+ off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
acc += tl.dot(a, b, out_dtype=tl.float32)
a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk
c = acc.to(C.dtype.element_ty)
# ---------------- #
# Epilogue #
# ---------------- #
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C \
+ off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# (A * B)^T = B^T * A^T
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
# shape constraints
a_dim = -2 if trans_a else -1
b_dim = -1 if trans_b else -2
Ka, Kb = a.shape[a_dim], b.shape[b_dim]
if Ka != Kb:
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
# allocate output
if out is None:
c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
else:
assert out.shape == (a.shape[0], lut.shape[0], block, block)
c = out
grid = [c.shape[1], 1, c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
)
return c
def sdd_lut(layout, block, device):
lut = layout.nonzero(as_tuple=False).to(device).int()
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa)
b = tl.load(pb)
acc += tl.dot(a, b, out_dtype=tl.float32)
pa += inc_a
pb += inc_b * stride_bk
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C \
+ off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
BS1 = b.size(1)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# allocate output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
# meta-parameter heuristics
TILE_N = 128
# compute output
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
# exit()
return c
def dsd_lut(layout, block, step, trans, device):
"""
Generates the look-up table for incrementing pointers in the DSD/DDS matmul.
Example (BLOCK=32, STEP=16)
[[1, 0, 0, 1, 0],
[0, 1, 1, 0, 1],
[1, 0, 1, 0, 0]]
Then the offsets for A are
[0 , 16, 32, 48] <- row 0
\\----/ \\----/
col=0 col=3
[64, 80, 96, 112, 128, 144] <- row 1
\\----/ \\----/ \\------/
col=1 col=2 col=3
[160, 176, 192, 208]
which leads to increments table
[0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]
Because B is dense, the offsets are
[0, 16, 96, 112] <- row 0
[32, 48, 64, 80] <- row 1
[0, 16, 64, 80] <- row 2
"""
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes * step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
# -------------------------------
# dense input pointer increments
# -------------------------------
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
# that is smaller than the block size, so we need to do a bit of extra work
# to handle this case
B_idx = nnz[:, 2] * block
B_incs = B_idx.clone()
B_incs[1:] -= B_idx[:-1]
div = block // step
B_incs = B_incs.view(-1, 1).repeat(1, div)
B_incs[:, 1:] = step
B_incs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
B_incs = B_incs.view(-1)
# -------------------------------
# sparse input pointer increments
# -------------------------------
# same as above, except that the increments are in the sparse memory layout
if trans:
A_idx = torch.arange(num_blocks, device=layout.device)
else:
A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone().long()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
A_incs = A_idx * block * block
A_incs[1:] -= A_idx[:-1] * block * block
A_incs = A_incs.view(-1, 1).repeat(1, div)
if trans:
A_incs[:, 1:] = step
A_incs[:, 0] -= (div - 1) * step
else:
A_incs[:, 1:] = step * block
A_incs[:, 0] -= (div - 1) * step * block
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
A_incs = A_incs.view(-1)
# create header
width = col_id.size(0)
offsets = offsets * 2 * div + 4 * width
segments = segments * div
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
# pad by a factor 2*MAX_NUM_STAGES
# to accommodate pre-fetching inside the kernel
pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
incs = torch.cat((incs, pad))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
# create locks
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
# AB = (B^T A^T)^T
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out)
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.db_lut = db_lut
ctx.db_width = db_width
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
ctx.trans_c = trans_c
ctx.has_out = out is not None
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
da, db = None, None
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
dout = dc if ctx.has_out else None
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, dout
class matmul:
def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
self.block = block
self.mode = mode
self.trans_a = trans_a
self.trans_b = trans_b
self.trans_c = trans_c
self.layout = layout
self.spdims = layout.shape
step = min(block, 32)
if self.mode == 'sdd':
self.c_lut, self.c_width = sdd_lut(layout, block, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
if self.mode == 'dsd':
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
return c

View File

@@ -0,0 +1,239 @@
import torch
import triton
import triton.language as tl
def num_warps(n):
if n <= 128:
return 1
if n <= 256:
return 2
if n <= 512:
return 4
if n <= 4096:
return 8
return 16
@triton.jit
def _blocksparse_softmax_fwd(
Out, A, stride_xz, LUT,
R, extent, stride_zr, stride_hr, # relative attention
scale, is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# pointer offset
off_a = z * stride_xz
off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx
off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load X
mask = block_n < size
a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
a = a.to(tl.float32)
# compute
out = a
out *= scale
# apply relative attention
if R is not None:
R += z * stride_zr
R += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent)
rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
out += rel_logits
out = out.to(tl.float32)
# apply causal mask
out = tl.where((ns > m) & is_causal, -float("inf"), out)
# computation
out = tl.softmax(out)
# write-back
tl.store(Out + off_a + lane_n, out, mask=mask)
@triton.jit
def _blocksparse_softmax_bwd(
DA, stride_zdx,
DOut, stride_zdout,
Out, stride_zout,
scale,
LUT,
DR, extent, stride_zr, stride_hr, stride_er,
is_causal,
ROW_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
IS_DENSE: tl.constexpr,
):
h = tl.program_id(0)
m = tl.program_id(1)
z = tl.program_id(2)
# create index ranges
hm = h * tl.num_programs(1) + m
lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
# extract information from LUT
header = LUT + (hm // BLOCK_SIZE) * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# row-col offset
off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
mask = block_n < size
# pointers
As = Out + z * stride_zout + off_mn
DOuts = DOut + z * stride_zdout + off_mn
# do not need to read column indices in the dense case
if IS_DENSE:
ns = tl.arange(0, ROW_SIZE)
else:
off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
ns = start_n * BLOCK_SIZE + lane_n
# load data
a = tl.load(As + lane_n, mask=mask, other=0.0)
a = a.to(tl.float32)
dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
dout = dout.to(tl.float32)
# compute
a = tl.where((ns > m) & is_causal & (a == a), 0., a)
da = a * (dout - tl.sum(a * dout, 0))
# apply relative attention
if DR is not None:
DR += z * stride_zr
DR += h * stride_hr
off_lo = (extent - m - 1) + ns
mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
da = da * scale
# convert da
# write-back
DAs = DA + z * stride_zdx + off_mn
tl.store(DAs + lane_n, da, mask=mask)
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
sizes = _empty.clone()
# sizes along rows
for h in range(layout.shape[0]):
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
total_sizes = sizes * block
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# block indices
columns = layout.nonzero(as_tuple=False)[:, 2]
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, columns)).type(torch.int32).to(device)
return lut, int(total_sizes.max())
@staticmethod
def forward(
ctx, a, scale, rel_logits, is_causal,
spdims, block, lut, maxlut, is_dense
):
if scale is not None and isinstance(scale, torch.Tensor):
assert scale.device.type == "cpu"
scale = scale.item()
M = a.shape[0]
grid = [spdims[0], spdims[1] * block, M]
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
# enqueue kernel
out = torch.empty_like(a)
_blocksparse_softmax_fwd[grid](
out, a, a.stride(0), lut,
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
scale,
is_causal,
BLOCK_SIZE=block,
ROW_SIZE=triton.next_power_of_2(maxlut),
IS_DENSE=is_dense,
num_warps=num_warps(maxlut)
)
# save to context
# ctx.mark_dirty(x)
ctx.save_for_backward(out, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.rel_shape = rel_shape
ctx.rel_strides = rel_strides
ctx.rel_dtype = a.dtype
ctx.is_dense = is_dense
ctx.is_causal = is_causal
return out
@staticmethod
def backward(ctx, dout):
# retrieve from context
out, lut = ctx.saved_tensors
# relative logits gradients
dr = None
if ctx.needs_input_grad[3]:
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
# run kernel
M = out.shape[0]
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
da = torch.empty_like(dout)
_blocksparse_softmax_bwd[grid](
da, da.stride(0),
dout, dout.stride(0),
out, out.stride(0),
ctx.scale,
lut,
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
ctx.is_causal,
BLOCK_SIZE=ctx.block,
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
IS_DENSE=ctx.is_dense,
num_warps=num_warps(ctx.maxlut)
)
return (da, None, None, dr, None,
None, None, None, None, None,
None,
None, None, None,
None,
None, None, None
)
class softmax:
def __init__(self, layout, block, device, is_dense=False):
self.spdims = layout.shape
self.layout = layout
self.block = block
self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device)
self.is_dense = is_dense
def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
if rel_logits is not None and rel_logits.dtype != a.dtype:
raise ValueError(f"relative position embedding must be {a.dtype}")
a = _softmax.apply(
a, scale, rel_logits, is_causal,
self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
)
return a

View File

@@ -0,0 +1,163 @@
import torch
import triton
import triton.language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [1]:
# TODO support block size 16 for MFMA dot op
for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 4 if block_n <= 64 else 8
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
#for split_k in [2, 4, 8, 16]:
# configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
# num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
def get_configs_compute_bound():
configs = []
for block_m in [64, 128, 256]:
for block_n in [64, 128, 256]:
for block_k in [32, 64, 128]:
num_warps = 8 if block_n <= 64 else 16
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=1, num_warps=num_warps))
return configs
@triton.autotune(
configs=[
] + get_configs_compute_bound() + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % args['BLOCK_K'] == 0,
})
@triton.jit
def _bmm_kernel(A, B, C, M, N, K,
stride_aq, stride_am, stride_ak,
stride_bq, stride_bk, stride_bn,
stride_cq, stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
):
pid = tl.program_id(0)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = tl.arange(0, BLOCK_K)
idx_q = tl.program_id(1) # batch dimension for BMM
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_q = tl.program_id(1) # batch dimension for BMM
idx_m = rm[:, None]
idx_n = rn[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn + idx_q * stride_cq)
mask = (idx_m < M) & (idx_n < N)
# handles write-back with reduction-splitting
tl.store(C, acc, mask=mask)
class _bmm(torch.autograd.Function):
kernel = _bmm_kernel
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
#only MR support Trans layout
if hasattr(torch, "corex"):
capability = torch.cuda.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
if (capability < 71):
if a.stride(0) >= 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) >= 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[0] == b.shape[0], "incompatible dimensions"
assert a.shape[2] == b.shape[1], "incompatible dimensions"
B, M, K = a.shape
_, _, N = b.shape
# allocates output
c = torch.empty((B, M, N), device=device, dtype=a.dtype)
if dot_out_dtype is None:
if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
else:
assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
if dot_out_dtype == torch.float16:
dot_out_dtype = tl.float16
elif dot_out_dtype in [torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), B, 1)
_bmm_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1), a.stride(2),
b.stride(0), b.stride(1), b.stride(2),
c.stride(0), c.stride(1), c.stride(2),
dot_out_dtype=dot_out_dtype,
GROUP_M=8)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None):
return _bmm._call(a, b, dot_out_dtype=dot_out_dtype)
bmm = _bmm.apply

View File

@@ -0,0 +1,94 @@
import torch
import triton
import triton.language as tl
def num_warps(N):
if N < 2048:
return 4
elif N < 8192:
return 8
return 16
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
# pointers to logit and probs
LOGITS = LOGITS + row * N + cols
WRIT_PROBS = PROBS + row * N + cols
READ_PROBS = PROBS + row * N + idx
# write-back negative log-probs
logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
logits = logits.to(tl.float32)
logits = logits - tl.max(logits, 0)
probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
tl.store(WRIT_PROBS, probs, mask=cols < N)
# There is a bug in the compiler, which fails to insert a barrier here.
# We add it explicitly for now. Will be fixed soon.
tl.debug_barrier()
# write-back loss
probs = tl.load(READ_PROBS)
tl.store(LOSS + row, probs)
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
idx = tl.load(IDX + row)
# pointers to probs
PROBS = PROBS + row * N + cols
# We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
# and we have -log(p[k]) stored in PROBS, so this is easy
probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
probs = tl.exp(probs.to(tl.float32))
delta = cols == idx
# write result in-place in PROBS
dout = tl.load(DPROBS + row)
din = (probs - delta) * dout
tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
class _cross_entropy(torch.autograd.Function):
@classmethod
def forward(cls, ctx, logits, indices):
# make sure we can use triton
assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
# make kernel
device, dtype = logits.device, logits.dtype
n_cols = logits.shape[-1]
# run the kernel
result = torch.empty_like(indices, dtype=dtype, device=device)
neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
grid = lambda opt: (logits.numel() // n_cols, )
_forward[grid](logits, neg_logprobs, indices, result, n_cols)
# save for backward
ctx.save_for_backward(neg_logprobs, indices)
return result
@classmethod
def backward(cls, ctx, dneg_logprobs):
"""We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
so we initialize the gradient as neg_logprobs, so we can just exponentiate
to get p[k], which is most of what we need... neg_logprobs will be
modified in place to become the gradient we want
"""
# load saved tensors
neg_logprobs, indices = ctx.saved_tensors
# run the kernel
# neg_logprobs will be modified in place to become our gradient:
n_cols = neg_logprobs.shape[-1]
grid = lambda opt: (neg_logprobs.numel() // n_cols, )
_backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
return neg_logprobs, None
cross_entropy = _cross_entropy.apply

View File

@@ -0,0 +1,271 @@
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention algorithm
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
"""
import torch
import triton
import triton.language as tl
from triton.common.build import is_corex
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# -- compute qk ----
k = tl.load(k_ptrs)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1. / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_prev)
tl.store(m_ptrs, m_prev)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
@triton.jit
def _bwd_preprocess(
Out, DO, L,
NewDO, Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
denom = tl.load(L + off_m).to(tl.float32)
# compute
do = do / denom[:, None]
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, sm_scale):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if not is_corex():
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
BLOCK = 128
else:
BLOCK = 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
num_warps = 4
_fwd_kernel[grid](
q, k, v, sm_scale,
L, m,
o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=Lk, num_warps=num_warps,
num_stages=2 if not is_corex() else 1,
)
ctx.save_for_backward(q, k, v, o, L, m)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
return o
@staticmethod
def backward(ctx, do):
BLOCK = 128 if not is_corex() else 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
num_warps = 16 if is_corex() and ctx.BLOCK_DMODEL > 64 else 8
q, k, v, o, l, m = ctx.saved_tensors
do = do.contiguous()
dq = torch.zeros_like(q, dtype=torch.float32)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
do_scaled = torch.empty_like(do)
delta = torch.empty_like(l)
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, l,
do_scaled, delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
l, m,
delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
num_stages=1,
)
return dq, dk, dv, None
attention = _attention.apply

184
pkgs/triton/ops/matmul.py Normal file
View File

@@ -0,0 +1,184 @@
import torch
import triton
import triton.language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
if hasattr(torch, "corex"):
return configs
for num_stages in [1, 2]:
# TODO support block size 16 for MFMA dot op
for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
def get_configs_compute_bound():
configs = []
if hasattr(torch, "corex"):
for block_m in [32, 64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for num_stages in [1, 2]:
num_warps = 16 if block_m >= 128 or block_n >=128 or block_k >= 128 else 8
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
return configs
def get_nv_config():
configs = []
if hasattr(torch, "corex"):
return configs
configs = [# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
]
return configs
@triton.autotune(
configs=get_nv_config() + get_configs_compute_bound() + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
class _matmul(torch.autograd.Function):
kernel = _kernel
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
if dot_out_dtype is None:
if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
else:
assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
if dot_out_dtype == torch.float16:
dot_out_dtype = tl.float16
elif dot_out_dtype in [torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
GROUP_M=8)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype)
matmul = _matmul.apply

View File

@@ -0,0 +1,164 @@
import heapq
import torch
import triton
import triton._C.libtriton.triton as _triton
from triton.runtime import driver
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
return tflops
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
''' return compute throughput in TOPS '''
total_warps = num_ctas * min(num_warps, 4)
num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs
tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
return tflops
def get_tflops(backend, device, num_ctas, num_warps, dtype):
capability = torch.cuda.get_device_capability(device)
if capability[0] < 8 and dtype == torch.float32:
return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
def estimate_matmul_time(
# backend, device,
num_warps, num_stages,
A, B, C,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
debug=False, **kwargs
):
''' return estimated running time in ms
= max(compute, loading) + store '''
backend = _triton.runtime.backend.CUDA
device = torch.cuda.current_device()
dtype = A.dtype
dtsize = A.element_size()
num_cta_m = triton.cdiv(M, BLOCK_M)
num_cta_n = triton.cdiv(N, BLOCK_N)
num_cta_k = SPLIT_K
num_ctas = num_cta_m * num_cta_n * num_cta_k
# If the input is smaller than the block size
M, N = max(M, BLOCK_M), max(N, BLOCK_N)
# time to compute
total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS
tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
compute_ms = total_ops / tput
# time to load data
num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"]
active_cta_ratio = min(1, num_ctas / num_sm)
active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate
active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5%
dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s
l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?)
# assume 80% of (following) loads are in L2 cache
load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
# total
total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB
total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
# loading time in ms
load_ms = total_dram / dram_bw + total_l2 / l2_bw
# estimate storing time
store_bw = dram_bw * 0.6 # :o
store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB
if SPLIT_K == 1:
store_ms = store_c_dram / store_bw
else:
reduce_bw = store_bw
store_ms = store_c_dram / reduce_bw
# c.zero_()
zero_ms = M * N * 2 / (1024 * 1024) / store_bw
store_ms += zero_ms
total_time_ms = compute_ms + load_ms + store_ms
if debug:
print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
f'loading time: {load_ms}ms, store time: {store_ms}ms, '
f'Activate CTAs: {active_cta_ratio*100}%')
return total_time_ms
def early_config_prune(configs, named_args):
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability()
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['A'].element_size()
dtype = named_args['A'].dtype
# 1. make sure we have enough smem
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages
max_shared_memory = driver.utils.get_device_properties(device)["max_shared_mem"]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
configs = pruned_configs
# Some dtypes do not allow atomic_add
if dtype not in [torch.float16, torch.float32]:
configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]
# group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
configs_map = {}
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \
kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages
key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps)
if key in configs_map:
configs_map[key].append((config, num_stages))
else:
configs_map[key] = [(config, num_stages)]
pruned_configs = []
for k, v in configs_map.items():
BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
if capability[0] >= 8:
# compute cycles (only works for ampere GPUs)
mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
mma_cycles = mmas / min(4, num_warps) * 8
ldgsts_latency = 300 # Does this matter?
optimal_num_stages = ldgsts_latency / mma_cycles
# nearest stages, prefer large #stages
nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)
for n in nearest:
pruned_configs.append(n[0])
else: # Volta & Turing only supports num_stages <= 2
if hasattr(torch, "corex"):
for stage in range(len(v)):
random_config = v[stage][0]
random_config.num_stages = v[stage][1]
pruned_configs.append(random_config)
else:
random_config = v[0][0]
random_config.num_stages = 2
pruned_configs.append(random_config)
return pruned_configs