First commit
This commit is contained in:
17
pkgs/triton/ops/__init__.py
Normal file
17
pkgs/triton/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# from .conv import _conv, conv
|
||||
from . import blocksparse
|
||||
from .cross_entropy import _cross_entropy, cross_entropy
|
||||
from .flash_attention import attention
|
||||
from .matmul import _matmul, matmul
|
||||
from .bmm_matmul import _bmm, bmm
|
||||
|
||||
__all__ = [
|
||||
"blocksparse",
|
||||
"_cross_entropy",
|
||||
"cross_entropy",
|
||||
"_matmul",
|
||||
"matmul",
|
||||
"_bmm",
|
||||
"bmm",
|
||||
"attention",
|
||||
]
|
||||
BIN
pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc
Normal file
Binary file not shown.
7
pkgs/triton/ops/blocksparse/__init__.py
Normal file
7
pkgs/triton/ops/blocksparse/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .matmul import matmul
|
||||
from .softmax import softmax
|
||||
|
||||
__all__ = [
|
||||
"matmul",
|
||||
"softmax",
|
||||
]
|
||||
BIN
pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc
Normal file
Binary file not shown.
BIN
pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc
Normal file
BIN
pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc
Normal file
Binary file not shown.
437
pkgs/triton/ops/blocksparse/matmul.py
Normal file
437
pkgs/triton/ops/blocksparse/matmul.py
Normal file
@@ -0,0 +1,437 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# ********************************************************
|
||||
# --------------------------------------------------------
|
||||
# Sparse = Dense x Dense (SDD)
|
||||
# This operation uses super-blocking to make sure that
|
||||
# it's done efficiently when small blocks can be grouped
|
||||
# together
|
||||
# --------------------------------------------------------
|
||||
# ********************************************************
|
||||
|
||||
|
||||
# EVEN_K is derived per launch: it is true when K is a multiple of TILE_K, in
# which case the inner loop can load full tiles without a mask.
@triton.heuristics({
    'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
    A, B, C,
    stride_za, stride_ha, stride_ma, stride_ak,
    stride_zb, stride_hb, stride_bk, stride_nb,
    stride_zc, stride_hc, stride_mc, stride_nc,
    K, grid_offset, lut,
    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
    BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
    # Sparse = Dense x Dense: each program computes one TILE_M x TILE_N output
    # tile for one non-zero block of the layout.
    #
    # `lut` holds 3 integers per block, read below as
    #   [head, row-block index into A, column-block index into B].
    # Program axis 0 selects the block (shifted by `grid_offset`), axis 2 the
    # batch entry.  The result is written to C at `block_id * stride_hc`, i.e.
    # the second dimension of C is indexed by block, not by head.
    # ------------ #
    # - Prologue - #
    # ------------ #
    block_id = tl.program_id(0) + grid_offset
    lut += block_id * 3
    # offsets
    off_z = tl.program_id(2)  # batch
    off_h = tl.load(lut + 0)  # head

    # initialize pointers to A
    start_am = tl.load(lut + 1)
    # `% BLOCK` wraps the row offsets so they never leave the selected block
    offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
    offs_ak = tl.arange(0, TILE_K)
    a_ptrs = A \
        + off_z * stride_za \
        + off_h * stride_ha \
        + offs_am[:, None] * stride_ma \
        + offs_ak[None, :] * stride_ak
    # initialize pointers to B
    start_bn = tl.load(lut + 2)
    offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
    offs_bk = tl.arange(0, TILE_K)
    b_ptrs = B \
        + off_z * stride_zb \
        + off_h * stride_hb \
        + offs_bn[None, :] * stride_nb \
        + offs_bk[:, None] * stride_bk
    # ---------------- #
    #    Inner Loop    #
    # ---------------- #
    # accumulate in float32 regardless of the input element type
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    # `k` counts the remaining extent of the reduction dimension
    for k in range(K, 0, -TILE_K):
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            # tail tile: lanes at or beyond the remaining extent read 0
            a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
            b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
        acc += tl.dot(a, b, out_dtype=tl.float32)
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk
    # cast the accumulator to the element type of the output pointer
    c = acc.to(C.dtype.element_ty)
    # ---------------- #
    #    Epilogue      #
    # ---------------- #
    offs_cm = tl.arange(0, TILE_M) % BLOCK
    offs_cn = tl.arange(0, TILE_N) % BLOCK
    pc = C \
        + off_z * stride_zc \
        + block_id * stride_hc \
        + offs_cm[:, None] * stride_mc \
        + offs_cn[None, :] * stride_nc
    # unconditional store of the whole tile
    tl.store(pc, c, mask=True)
|
||||
|
||||
|
||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
    """Launch the SDD kernel: dense ``a`` times dense ``b`` into a block-sparse result.

    Args:
        a, b: 4-D operands; strides of dimensions 0-3 are forwarded to the kernel.
        trans_a, trans_b: whether the corresponding operand is used transposed.
        trans_c: whether the result is transposed; implemented by swapping the
            operands, since (A * B)^T = B^T * A^T.
        spdims: accepted for signature parity with the other modes; not used here.
        block: edge length of one square block.
        lut: look-up table with one row per non-zero block.
        widths: accepted for signature parity; not used here.
        out: optional pre-allocated output of shape
            ``(a.shape[0], lut.shape[0], block, block)``.

    Returns:
        The output tensor (``out`` itself when it was supplied).

    Raises:
        ValueError: if the reduction dimensions of the operands differ.
    """
    # copy only when neither of the two trailing dimensions is unit-stride
    if not (a.stride(2) == 1 or a.stride(3) == 1):
        a = a.contiguous()
    if not (b.stride(2) == 1 or b.stride(3) == 1):
        b = b.contiguous()
    # (A * B)^T = B^T * A^T
    if trans_c:
        a, b, trans_a, trans_b = b, a, not trans_b, not trans_a
    # shape constraints
    Ka = a.shape[-2 if trans_a else -1]
    Kb = b.shape[-1 if trans_b else -2]
    if Ka != Kb:
        raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
    # allocate (or validate) the output: one block x block tile per LUT row
    out_shape = (a.shape[0], lut.shape[0], block, block)
    if out is not None:
        assert out.shape == out_shape
        c = out
    else:
        c = torch.empty(out_shape, dtype=a.dtype, device=a.device)
    # a transposed operand is expressed by swapping its two trailing strides
    a_row, a_red = (3, 2) if trans_a else (2, 3)
    b_red, b_col = (3, 2) if trans_b else (2, 3)
    # one program per non-zero block (axis 0) and per batch entry (axis 2)
    grid = [c.shape[1], 1, c.shape[0]]
    _sdd_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(a_row), a.stride(a_red),
        b.stride(0), b.stride(1), b.stride(b_red), b.stride(b_col),
        c.stride(0), c.stride(1), c.stride(2), c.stride(3),
        Ka, 0, lut,
        TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
        num_warps=4,
    )
    return c
|
||||
|
||||
|
||||
def sdd_lut(layout, block, device):
    """Build the look-up table used by the SDD kernel.

    The table has one row per non-zero entry of ``layout``, holding the index
    of that entry along each layout dimension, as 32-bit integers on ``device``.
    ``block`` is accepted for signature parity and is not used.

    Returns:
        ``(lut, None)`` -- the second element is the (absent) width.
    """
    coords = layout.nonzero(as_tuple=False)
    table = coords.to(device).int().contiguous()
    return table, None
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
# in order to minimize computations in the inner loop of the matmul kernel.
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
def _dsd_kernel(
    A, B, C,
    stride_az, stride_ha, stride_am, stride_ak,
    stride_zb, stride_hb, stride_bk, stride_bn,
    stride_zc, stride_hc, stride_cm, stride_cn,
    DS0, DS1, lut,
    TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
    # Dense = Sparse x Dense: A is the block-sparse operand, B and C are dense.
    #
    # `lut` starts with a header of 4 integers per program along axis 1, read
    # below as [offset, K, column, head]; `offset` points into the rest of the
    # table, which is a sequence of (B increment, A increment) pairs.
    # DS0 bounds the dense dimension tiled along program axis 0; DS1 is not
    # read by this kernel.
    # ------------ #
    # - Prologue - #
    # ------------ #
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    num_pid_m = tl.num_programs(0)
    num_pid_n = tl.num_programs(1)
    # remap the 2-D program index in groups of GROUP_SIZE_M
    pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
    pidz = tl.program_id(2)
    header = lut + pid_n * 4
    offset = tl.load(header + 0)
    K = tl.load(header + 1)
    column = tl.load(header + 2)
    off_h = tl.load(header + 3)
    pinc = lut + offset
    # initialize pointers to A (sparse)
    # the first pair of the segment holds starting positions rather than deltas
    block_id = tl.load(pinc + 1)
    block_id = tl.multiple_of(block_id, 8)  # compiler hint
    offs_am = tl.arange(0, TILE_M)
    offs_ak = tl.arange(0, TILE_K)
    pa = A + pidz * stride_az \
        + block_id * stride_ha \
        + offs_am[:, None] * stride_am \
        + offs_ak[None, :] * stride_ak
    # initialize pointers to B (dense)
    offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
    # `% DS0` wraps out-of-range columns back in bounds so the loads below need
    # no mask; the final store masks them out instead
    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
    start_bk = tl.load(pinc)
    start_bk = tl.multiple_of(start_bk, 8)  # compiler hint
    offs_bk = start_bk + tl.arange(0, TILE_K)
    pb = B + pidz * stride_zb \
        + off_h * stride_hb \
        + offs_bn[None, :] * stride_bn \
        + offs_bk[:, None] * stride_bk
    # ---------------- #
    #    Inner Loop    #
    # ---------------- #
    acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    # load the increments for the first step before entering the loop
    pinc += 2
    inc_a = tl.load(pinc + 1)
    inc_a = tl.multiple_of(inc_a, 8)
    inc_b = tl.load(pinc)
    inc_b = tl.multiple_of(inc_b, 8)
    for k in range(K, 0, -TILE_K):
        a = tl.load(pa)
        b = tl.load(pb)
        acc += tl.dot(a, b, out_dtype=tl.float32)
        # the A increment is applied as-is, the B increment is scaled by stride_bk
        pa += inc_a
        pb += inc_b * stride_bk
        # fetch the increments of the next step; on the last iteration this
        # reads one pair beyond the segment, so the table must extend that far
        pinc += 2
        inc_a = tl.load(pinc + 1)
        inc_a = tl.multiple_of(inc_a, 8)
        inc_b = tl.load(pinc)
        inc_b = tl.multiple_of(inc_b, 8)
    c = acc.to(C.dtype.element_ty)
    # initialize pointers to C
    offs_cm = column * TILE_M + tl.arange(0, TILE_M)
    offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
    pc = C \
        + off_h * stride_hc \
        + pidz * stride_zc \
        + offs_cm[:, None] * stride_cm \
        + offs_cn[None, :] * stride_cn
    # only columns below DS0 are written
    tl.store(pc, c, mask=offs_cn[None, :] < DS0)
|
||||
|
||||
|
||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
    """Launch the DSD kernel: block-sparse ``a`` times dense ``b`` into a dense result.

    Args:
        a: sparse operand (4-D; all four strides are forwarded to the kernel).
        b: dense 4-D operand.
        trans_a, trans_b, trans_c: transposition flags for the operands and the
            result; they are implemented by swapping the two trailing strides.
        spdims: shape of the sparsity layout.
        block: edge length of one square block.
        lut: look-up table for the kernel.
        width: extent of the launch grid along axis 1.
        out: optional pre-allocated output; its shape must match the computed one.

    Returns:
        The output tensor (``out`` itself when it was supplied).
    """
    # copy only when neither of the two trailing dimensions is unit-stride
    if not (a.stride(2) == 1 or a.stride(3) == 1):
        a = a.contiguous()
    if not (b.stride(2) == 1 or b.stride(3) == 1):
        b = b.contiguous()
    # shapes: the sparse extent comes from the layout, the rest from b
    sparse_extent = block * spdims[2 if trans_a else 1]
    batch, heads = b.size(0), b.size(1)
    dense_extent = b.size(2 if trans_b else 3)
    if trans_c:
        out_shape = (batch, heads, dense_extent, sparse_extent)
    else:
        out_shape = (batch, heads, sparse_extent, dense_extent)
    # allocate (or validate) the output
    if out is not None:
        assert out.shape == out_shape
        c = out
    else:
        c = torch.empty(out_shape, dtype=a.dtype, device=a.device)
    # meta-parameter heuristics
    tile_n = 128
    # stride selection for (possibly transposed) operands and result
    a_row, a_red = (3, 2) if trans_a else (2, 3)
    b_red, b_col = (3, 2) if trans_b else (2, 3)
    c_row, c_col = (3, 2) if trans_c else (2, 3)

    def grid(meta):
        # axis 0: tiles of the dense extent, axis 1: LUT headers, axis 2: batch
        return [triton.cdiv(dense_extent, meta['TILE_N']), width, batch]

    _dsd_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(a_row), a.stride(a_red),
        b.stride(0), b.stride(1), b.stride(b_red), b.stride(b_col),
        c.stride(0), c.stride(1), c.stride(c_row), c.stride(c_col),
        dense_extent, sparse_extent, lut,
        TILE_M=block, TILE_N=tile_n, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
        num_warps=4, GROUP_SIZE_M=4,
    )
    return c
|
||||
|
||||
|
||||
def dsd_lut(layout, block, step, trans, device):
    """
    Build the pointer-increment look-up table used by the DSD/DDS matmul.

    The returned table is the concatenation of
      * a header with 4 entries per reduction segment:
        (offset into the increment table, segment length, column id, head id);
      * interleaved (dense increment, sparse increment) pairs;
      * 20 trailing zeros, so the kernel can pre-fetch past the last pair.

    Example (BLOCK=32, STEP=16)
    [[1, 0, 0, 1, 0],
     [0, 1, 1, 0, 1],
     [1, 0, 1, 0, 0]]

    Then the offsets for A are
     [0 , 16, 32, 48] <- row 0
      \\----/ \\----/
      col=0   col=3
     [64, 80, 96, 112, 128, 144] <- row 1
      \\----/ \\----/  \\------/
       col=1   col=2    col=3
     [160, 176, 192, 208]
    which leads to increments table
    [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16]

    Because B is dense, the offsets are
    [0, 16, 96, 112] <- row 0
    [32, 48, 64, 80]  <- row 1
    [0, 16, 64, 80]   <- row 2

    Returns:
        ``(lut, width)``: the int32 table on ``device`` and the number of
        header rows.
    """
    # number of non-zero blocks in every reduction segment
    sizes = torch.sum(layout, 2 if trans else 1)
    head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True)
    sizes = sizes.flatten()
    segments = sizes * step
    # coordinates of the non-zero blocks, in the order they are reduced over
    source = layout if trans else layout.transpose(1, 2)
    nnz = source.nonzero(as_tuple=False)
    num_blocks = nnz.size(0)
    # index of the first block of every segment (clamped to a valid block)
    offsets = torch.zeros_like(sizes)
    offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
    offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
    # first block of every non-empty segment
    seg_start = offsets[segments > 0]
    # the kernel may advance by a step smaller than the block size, so every
    # block is expanded into `div` increments
    div = block // step
    # -------------------------------
    # dense input pointer increments
    # -------------------------------
    B_idx = nnz[:, 2] * block
    B_incs = B_idx.clone()
    B_incs[1:] -= B_idx[:-1]
    B_incs = B_incs.view(-1, 1).repeat(1, div)
    B_incs[:, 1:] = step
    B_incs[:, 0] -= (div - 1) * step
    # first increment for each reduction is actually the offset
    B_incs[seg_start, 0] = B_idx[seg_start]
    B_incs = B_incs.view(-1)
    # -------------------------------
    # sparse input pointer increments
    # -------------------------------
    # same as above, except that the increments are in the sparse memory layout
    if trans:
        A_idx = torch.arange(num_blocks, device=layout.device)
    else:
        A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
        current_offset = 0
        for z in range(layout.size(0)):
            # number the non-zero blocks of this head 1..msum in storage order,
            # then read the numbers back in transposed order
            numbered = layout[z, :, :].clone().long()
            msum = numbered.sum()
            numbered[numbered > 0] = 1 + torch.arange(msum, device=layout.device)
            flipped = numbered.T
            A_idx = torch.cat((A_idx, current_offset + flipped[flipped > 0] - 1))
            current_offset += msum
    A_incs = A_idx * block * block
    A_incs[1:] -= A_idx[:-1] * block * block
    A_incs = A_incs.view(-1, 1).repeat(1, div)
    inner_inc = step if trans else step * block
    A_incs[:, 1:] = inner_inc
    A_incs[:, 0] -= (div - 1) * inner_inc
    A_incs[seg_start, 0] = A_idx[seg_start]
    A_incs = A_incs.view(-1)
    # create header
    width = col_id.size(0)
    offsets = offsets * 2 * div + 4 * width
    segments = segments * div
    header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
    # create increments
    incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
    # pad by a factor 2*MAX_NUM_STAGES
    # to accommodate pre-fetching inside the kernel
    pad = torch.zeros(20, device=incs.device, dtype=incs.dtype)
    # create lut
    lut = torch.cat((header, torch.cat((incs, pad))))
    lut = lut.type(torch.int32).to(device)
    return lut, width
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
# AB = (B^T A^T)^T
|
||||
|
||||
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
    """Dense = Dense x Sparse, computed through the DSD path.

    Uses AB = (B^T A^T)^T: the operands are swapped and all three
    transposition flags are inverted before delegating to ``dsd_matmul``.
    """
    swapped_trans_a = not trans_b
    swapped_trans_b = not trans_a
    swapped_trans_c = not trans_c
    return dsd_matmul(
        b, a, swapped_trans_a, swapped_trans_b, swapped_trans_c,
        spdims, block, lut, width, out=out,
    )
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
|
||||
|
||||
class _matmul(torch.autograd.Function):
    """Autograd function dispatching to the SDD / DSD / DDS matmul launchers."""

    # mode string -> launcher
    fn = dict(sdd=sdd_matmul, dsd=dsd_matmul, dds=dds_matmul)

    @staticmethod
    def forward(
        ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
        c_lut, c_width, da_lut, da_width, db_lut, db_width, out
    ):
        """Compute the product for ``mode`` and stash what backward needs."""
        launcher = _matmul.fn[mode]
        c = launcher(a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
        # save for backward
        ctx.save_for_backward(a, b)
        ctx.da_lut, ctx.da_width = da_lut, da_width
        ctx.db_lut, ctx.db_width = db_lut, db_width
        ctx.mode = mode
        ctx.spdims = spdims
        ctx.block = block
        ctx.trans_a, ctx.trans_b, ctx.trans_c = trans_a, trans_b, trans_c
        ctx.has_out = out is not None
        return c

    @staticmethod
    def backward(ctx, dc):
        """Return one gradient per forward input (15 entries)."""
        a, b = ctx.saved_tensors
        mode = ctx.mode
        da = db = None
        # gradients w.r.t. a: the mode of the product is a permutation of the
        # forward mode letters
        if ctx.needs_input_grad[0]:
            mode_da = mode[1] + mode[0] + mode[2]
            da = _matmul.fn[mode_da](
                dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
            )
        # gradients w.r.t. b
        if ctx.needs_input_grad[1]:
            mode_db = mode[2] + mode[1] + mode[0]
            db = _matmul.fn[mode_db](
                a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
            )
        # the incoming gradient is passed through for `out` when one was given
        dout = dc if ctx.has_out else None
        # a, b, then 12 non-differentiable arguments, then out
        return (da, db) + (None,) * 12 + (dout,)
|
||||
|
||||
|
||||
class matmul:
    """Block-sparse matrix multiplication with pre-computed look-up tables.

    ``mode`` names which of (output, first operand, second operand) is sparse:
    ``'sdd'``, ``'dsd'`` or ``'dds'``.  The look-up tables for the forward
    product and for both gradients are built once in the constructor.
    """

    def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
        if mode not in ['sdd', 'dsd', 'dds']:
            raise NotImplementedError('Supported modes are: sdd, dsd, dds')
        self.block = block
        self.mode = mode
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.trans_c = trans_c
        self.layout = layout
        self.spdims = layout.shape
        step = min(block, 32)

        def sparse_out():
            # table for a product whose result is sparse
            return sdd_lut(layout, block, device)

        def dense_out(trans):
            # table for a product whose result is dense
            return dsd_lut(layout, block, step, trans, device)

        if mode == 'sdd':
            self.c_lut, self.c_width = sparse_out()
            self.da_lut, self.da_width = dense_out(True)
            self.db_lut, self.db_width = dense_out(False)
        elif mode == 'dsd':
            self.c_lut, self.c_width = dense_out(not self.trans_a)
            self.da_lut, self.da_width = sparse_out()
            self.db_lut, self.db_width = dense_out(self.trans_a)
        else:  # 'dds'
            self.c_lut, self.c_width = dense_out(self.trans_b)
            self.da_lut, self.da_width = dense_out(not self.trans_b)
            self.db_lut, self.db_width = sparse_out()

    def __call__(self, a, b, out=None):
        """Multiply ``a`` and ``b``; ``out`` optionally receives the result."""
        return _matmul.apply(
            a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
            self.c_lut, self.c_width,
            self.da_lut, self.da_width,
            self.db_lut, self.db_width,
            out
        )
|
||||
239
pkgs/triton/ops/blocksparse/softmax.py
Normal file
239
pkgs/triton/ops/blocksparse/softmax.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(n):
    """Return the warp count used for a row of ``n`` elements.

    The count is 1, 2, 4 or 8 for rows of at most 128, 256, 512 or 4096
    elements respectively, and 16 for anything larger.
    """
    for upper_bound, warps in ((128, 1), (256, 2), (512, 4), (4096, 8)):
        if n <= upper_bound:
            return warps
    return 16
|
||||
|
||||
|
||||
@triton.jit
def _blocksparse_softmax_fwd(
    Out, A, stride_xz, LUT,
    R, extent, stride_zr, stride_hr,  # relative attention
    scale, is_causal,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_DENSE: tl.constexpr,
):
    # Forward softmax over one row of a block-sparse matrix.
    # Program axes: 0 = head, 1 = row, 2 = batch entry.
    #
    # LUT starts with 2 integers per block-row, read below as [size, offset]:
    # the number of non-zero blocks in that block-row and the index of its
    # first block.  The column-block indices follow the header.
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges
    hm = h * tl.num_programs(1) + m
    # each of the ROW_SIZE lanes is addressed as (block within the row, lane
    # within the block)
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # pointer offset
    off_a = z * stride_xz
    off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE  # block index
    off_a += (m % BLOCK_SIZE) * BLOCK_SIZE  # row index
    # do not need to read column indices in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        # skip the header (2 entries per block-row) to reach the column indices
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load X
    # lanes beyond the last non-zero block read -inf
    mask = block_n < size
    a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf"))
    a = a.to(tl.float32)
    # compute
    out = a
    out *= scale
    # apply relative attention
    if R is not None:
        R += z * stride_zr
        R += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent)
        rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0)
        out += rel_logits
    out = out.to(tl.float32)
    # apply causal mask
    # columns with index greater than the row index are set to -inf
    out = tl.where((ns > m) & is_causal, -float("inf"), out)
    # computation
    out = tl.softmax(out)
    # write-back
    tl.store(Out + off_a + lane_n, out, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
def _blocksparse_softmax_bwd(
    DA, stride_zdx,
    DOut, stride_zdout,
    Out, stride_zout,
    scale,
    LUT,
    DR, extent, stride_zr, stride_hr, stride_er,
    is_causal,
    ROW_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_DENSE: tl.constexpr,
):
    # Backward softmax over one row of a block-sparse matrix.
    # Program axes: 0 = head, 1 = row, 2 = batch entry.
    # Reads the forward output (Out) and its gradient (DOut); writes the input
    # gradient to DA and, when DR is given, the pre-scale gradient to DR.
    # `stride_er` is not read by this kernel.
    h = tl.program_id(0)
    m = tl.program_id(1)
    z = tl.program_id(2)
    # create index ranges
    hm = h * tl.num_programs(1) + m
    lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE
    block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE
    # extract information from LUT
    # header entries are [size, offset] per block-row
    header = LUT + (hm // BLOCK_SIZE) * 2
    size = tl.load(header + 0)
    offset = tl.load(header + 1)
    # row-col offset
    off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE
    off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE
    mask = block_n < size
    # pointers
    As = Out + z * stride_zout + off_mn
    DOuts = DOut + z * stride_zdout + off_mn
    # do not need to read column indices in the dense case
    if IS_DENSE:
        ns = tl.arange(0, ROW_SIZE)
    else:
        # skip the header (2 entries per block-row) to reach the column indices
        off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE
        start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0)
        ns = start_n * BLOCK_SIZE + lane_n
    # load data
    a = tl.load(As + lane_n, mask=mask, other=0.0)
    a = a.to(tl.float32)
    dout = tl.load(DOuts + lane_n, mask=mask, other=0.0)
    dout = dout.to(tl.float32)
    # compute
    # zero the causally-masked columns; `a == a` is false for NaN, so NaN
    # entries are left untouched by this select
    a = tl.where((ns > m) & is_causal & (a == a), 0., a)
    # softmax gradient: out * (dout - sum(out * dout))
    da = a * (dout - tl.sum(a * dout, 0))
    # apply relative attention
    if DR is not None:
        DR += z * stride_zr
        DR += h * stride_hr
        off_lo = (extent - m - 1) + ns
        mask_lo = (off_lo >= 0) & (off_lo < extent) & mask
        # stored before the multiplication by `scale` below
        tl.store(DR + m * extent + off_lo, da, mask=mask_lo)
    da = da * scale
    # convert da
    # write-back
    DAs = DA + z * stride_zdx + off_mn
    tl.store(DAs + lane_n, da, mask=mask)
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
_empty = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
sizes = _empty.clone()
|
||||
# sizes along rows
|
||||
for h in range(layout.shape[0]):
|
||||
sizes = torch.cat((sizes, layout[h, :, :].sum(-1)))
|
||||
total_sizes = sizes * block
|
||||
# offsets in block format
|
||||
offsets = torch.zeros_like(sizes)
|
||||
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
|
||||
# block indices
|
||||
columns = layout.nonzero(as_tuple=False)[:, 2]
|
||||
header = torch.stack((sizes, offsets), dim=1).view(-1)
|
||||
lut = torch.cat((header, columns)).type(torch.int32).to(device)
|
||||
return lut, int(total_sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, scale, rel_logits, is_causal,
|
||||
spdims, block, lut, maxlut, is_dense
|
||||
):
|
||||
if scale is not None and isinstance(scale, torch.Tensor):
|
||||
assert scale.device.type == "cpu"
|
||||
scale = scale.item()
|
||||
M = a.shape[0]
|
||||
grid = [spdims[0], spdims[1] * block, M]
|
||||
rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape
|
||||
rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride()
|
||||
# enqueue kernel
|
||||
out = torch.empty_like(a)
|
||||
_blocksparse_softmax_fwd[grid](
|
||||
out, a, a.stride(0), lut,
|
||||
rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn
|
||||
scale,
|
||||
is_causal,
|
||||
BLOCK_SIZE=block,
|
||||
ROW_SIZE=triton.next_power_of_2(maxlut),
|
||||
IS_DENSE=is_dense,
|
||||
num_warps=num_warps(maxlut)
|
||||
)
|
||||
# save to context
|
||||
# ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(out, lut)
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.maxlut = maxlut
|
||||
ctx.scale = scale
|
||||
ctx.rel_shape = rel_shape
|
||||
ctx.rel_strides = rel_strides
|
||||
ctx.rel_dtype = a.dtype
|
||||
ctx.is_dense = is_dense
|
||||
ctx.is_causal = is_causal
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout):
|
||||
# retrieve from context
|
||||
out, lut = ctx.saved_tensors
|
||||
# relative logits gradients
|
||||
dr = None
|
||||
if ctx.needs_input_grad[3]:
|
||||
dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device)
|
||||
# run kernel
|
||||
M = out.shape[0]
|
||||
grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M)
|
||||
da = torch.empty_like(dout)
|
||||
_blocksparse_softmax_bwd[grid](
|
||||
da, da.stride(0),
|
||||
dout, dout.stride(0),
|
||||
out, out.stride(0),
|
||||
ctx.scale,
|
||||
lut,
|
||||
dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2],
|
||||
ctx.is_causal,
|
||||
BLOCK_SIZE=ctx.block,
|
||||
ROW_SIZE=triton.next_power_of_2(ctx.maxlut),
|
||||
IS_DENSE=ctx.is_dense,
|
||||
num_warps=num_warps(ctx.maxlut)
|
||||
)
|
||||
return (da, None, None, dr, None,
|
||||
None, None, None, None, None,
|
||||
None,
|
||||
None, None, None,
|
||||
None,
|
||||
None, None, None
|
||||
)
|
||||
|
||||
|
||||
class softmax:
    """Block-sparse softmax with a look-up table built once per layout."""

    def __init__(self, layout, block, device, is_dense=False):
        self.layout = layout
        self.spdims = layout.shape
        self.block = block
        lut, maxlut = _softmax.make_lut(self.layout, self.block, device)
        self.lut = lut
        self.maxlut = maxlut
        self.is_dense = is_dense

    def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False):
        """Apply the softmax to ``a``.

        Raises:
            ValueError: if ``rel_logits`` is given with a dtype different
                from that of ``a``.
        """
        dtype_mismatch = rel_logits is not None and rel_logits.dtype != a.dtype
        if dtype_mismatch:
            raise ValueError(f"relative position embedding must be {a.dtype}")
        return _softmax.apply(
            a, scale, rel_logits, is_causal,
            self.spdims, self.block, self.lut, self.maxlut, self.is_dense,
        )
|
||||
163
pkgs/triton/ops/bmm_matmul.py
Normal file
163
pkgs/triton/ops/bmm_matmul.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
def init_to_zero(name):
    """Return a hook that zeroes, in place, the argument called ``name``.

    The hook takes the mapping of kernel arguments and returns whatever the
    argument's ``zero_()`` method returns.
    """
    def _zero_argument(nargs):
        return nargs[name].zero_()

    return _zero_argument
|
||||
|
||||
def get_configs_io_bound():
    """Return the autotuning candidates with small tiles.

    Every candidate uses a single pipeline stage and ``SPLIT_K == 1``; no
    split-K variants are generated.  The smallest ``BLOCK_M`` is 16, except on
    HIP builds of torch or when ``torch.corex`` exists, where it is 32.
    """
    # TODO support block size 16 for MFMA dot op
    if torch.version.hip is None and not hasattr(torch, "corex"):
        block_m_choices = [16, 32]
    else:
        block_m_choices = [32, 64]
    return [
        triton.Config(
            {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
            num_stages=num_stages,
            # wider tiles get twice the warps
            num_warps=4 if block_n <= 64 else 8,
        )
        for num_stages in [1]
        for block_m in block_m_choices
        for block_k in [32, 64]
        for block_n in [32, 64, 128, 256]
    ]
|
||||
|
||||
def get_configs_compute_bound():
    """Return the autotuning candidates with large tiles.

    All combinations of ``BLOCK_M``/``BLOCK_N`` in {64, 128, 256} and
    ``BLOCK_K`` in {32, 64, 128}, each with one pipeline stage and
    ``SPLIT_K == 1``.
    """
    tile_mn = [64, 128, 256]
    tile_k = [32, 64, 128]
    return [
        triton.Config(
            {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
            num_stages=1,
            # wider tiles get twice the warps
            num_warps=8 if block_n <= 64 else 16,
        )
        for block_m in tile_mn
        for block_n in tile_mn
        for block_k in tile_k
    ]
|
||||
|
||||
|
||||
# Autotuned over the compute-bound and io-bound candidate lists, re-tuned
# whenever M, N or K changes; candidates are pruned with the helpers from
# matmul_perf_model before benchmarking.
@triton.autotune(
    configs=[
    ] + get_configs_compute_bound() + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
# EVEN_K is true when K is a multiple of BLOCK_K, so the loop can skip masking.
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_K'] == 0,
})
@triton.jit
def _bmm_kernel(A, B, C, M, N, K,
                stride_aq, stride_am, stride_ak,
                stride_bq, stride_bk, stride_bn,
                stride_cq, stride_cm, stride_cn,
                dot_out_dtype: tl.constexpr,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                ):
    # Batched matmul: program axis 0 enumerates the BLOCK_M x BLOCK_N output
    # tiles of one matrix, program axis 1 is the batch index.
    # SPLIT_K is part of the signature but is not read in this body.
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    # the last group may hold fewer than GROUP_M rows of tiles
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # `% M` / `% N` wrap out-of-range rows and columns back in bounds so the
    # loads need no row/column mask; the final store masks them out
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    # the accumulator uses the element type chosen by the caller
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # `k` counts the remaining extent of the reduction dimension
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # tail tile: lanes at or beyond the remaining extent read 0
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn + idx_q * stride_cq)
    mask = (idx_m < M) & (idx_n < N)
    # plain masked write-back; this body has no reduction-splitting path
    tl.store(C, acc, mask=mask)
|
||||
|
||||
class _bmm(torch.autograd.Function):
    """Autograd function for batched matmul; only ``forward`` is defined."""

    kernel = _bmm_kernel

    _locks = {}

    @staticmethod
    def _call(a, b, dot_out_dtype):
        """Multiply ``a`` (B, M, K) by ``b`` (B, K, N) and return a (B, M, N) tensor.

        ``dot_out_dtype`` selects the accumulator type: ``None`` picks float32
        for float16/float32/bfloat16 inputs and int32 otherwise; a
        ``torch.dtype`` is mapped to the corresponding Triton type.
        """
        device = a.device
        # handle non-contiguous inputs if necessary
        # NOTE(review): these checks look at strides 0 and 1 of 3-D tensors;
        # confirm whether strides 1 and 2 were intended.
        if min(a.stride(0), a.stride(1)) > 1:
            a = a.contiguous()
        if min(b.stride(0), b.stride(1)) > 1:
            b = b.contiguous()

        # only MR support Trans layout
        if hasattr(torch, "corex"):
            cap = torch.cuda.get_device_capability(device)
            if cap[0] * 10 + cap[1] < 71:
                if a.stride(0) >= 1 and a.stride(1) > 1:
                    a = a.contiguous()
                if b.stride(0) >= 1 and b.stride(1) > 1:
                    b = b.contiguous()
        # checks constraints
        assert a.shape[0] == b.shape[0], "incompatible dimensions"
        assert a.shape[2] == b.shape[1], "incompatible dimensions"
        B, M, K = a.shape
        _, _, N = b.shape
        # allocates output
        c = torch.empty((B, M, N), device=device, dtype=a.dtype)
        # resolve the accumulator type
        if dot_out_dtype is not None:
            assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
            if dot_out_dtype == torch.float16:
                dot_out_dtype = tl.float16
            elif dot_out_dtype in [torch.float32, torch.bfloat16]:
                dot_out_dtype = tl.float32
            else:
                dot_out_dtype = tl.int32
        else:
            floating_input = a.dtype in [torch.float16, torch.float32, torch.bfloat16]
            dot_out_dtype = tl.float32 if floating_input else tl.int32

        # launch kernel
        def grid(META):
            # axis 0: output tiles of one matrix, axis 1: batch
            tiles = triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])
            return (tiles, B, 1)

        _bmm_kernel[grid](a, b, c, M, N, K,
                          a.stride(0), a.stride(1), a.stride(2),
                          b.stride(0), b.stride(1), b.stride(2),
                          c.stride(0), c.stride(1), c.stride(2),
                          dot_out_dtype=dot_out_dtype,
                          GROUP_M=8)
        return c

    @staticmethod
    def forward(ctx, a, b, dot_out_dtype=None):
        """Forward pass; delegates to ``_call``."""
        return _bmm._call(a, b, dot_out_dtype=dot_out_dtype)


# public entry point
bmm = _bmm.apply
|
||||
94
pkgs/triton/ops/cross_entropy.py
Normal file
94
pkgs/triton/ops/cross_entropy.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(N):
    """Return the warp count used for a row of ``N`` elements.

    4 warps below 2048 elements, 8 warps below 8192, otherwise 16.
    """
    if N < 2048:
        return 4
    if N < 8192:
        return 8
    return 16
|
||||
|
||||
|
||||
# Cross-entropy forward kernel: one program per row (program_id(0)).
# For each row it stores the negative log-softmax of the logits into PROBS
# and the entry selected by IDX[row] into LOSS[row].
# Rows are addressed as ``base + row * N``, so LOGITS and PROBS are indexed
# as dense row-major buffers with N columns.
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to logit and probs
    LOGITS = LOGITS + row * N + cols
    WRIT_PROBS = PROBS + row * N + cols
    READ_PROBS = PROBS + row * N + idx
    # write-back negative log-probs
    # Columns past N are padded with -inf so they do not affect max or sum(exp).
    logits = tl.load(LOGITS, mask=cols < N, other=-float('inf'))
    logits = logits.to(tl.float32)
    # Subtract the row maximum before exponentiating.
    logits = logits - tl.max(logits, 0)
    probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits
    tl.store(WRIT_PROBS, probs, mask=cols < N)
    # There is a bug in the compiler, which fails to insert a barrier here.
    # We add it explicitly for now. Will be fixed soon.
    tl.debug_barrier()
    # write-back loss: re-read the entry just written for the target column
    probs = tl.load(READ_PROBS)
    tl.store(LOSS + row, probs)
|
||||
|
||||
|
||||
# Cross-entropy backward kernel: one program per row (program_id(0)).
# PROBS holds the negative log-probabilities on entry and is overwritten
# in place with the gradient with respect to the logits.
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    idx = tl.load(IDX + row)
    # pointers to probs
    PROBS = PROBS + row * N + cols
    # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
    # and we have -log(p[k]) stored in PROBS, so this is easy
    # Masked-out lanes load +inf, which becomes exp(-inf) = 0 after negation.
    probs = -tl.load(PROBS, mask=cols < N, other=float('inf'))
    probs = tl.exp(probs.to(tl.float32))
    # delta is true only in the target column of this row
    delta = cols == idx
    # write result in-place in PROBS
    dout = tl.load(DPROBS + row)
    din = (probs - delta) * dout
    tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N)
|
||||
|
||||
|
||||
class _cross_entropy(torch.autograd.Function):
    """Cross-entropy loss computed row by row with the ``_forward`` / ``_backward`` kernels."""

    @classmethod
    def forward(cls, ctx, logits, indices):
        """Return the per-row loss ``-log(softmax(logits)[indices])``.

        Args:
            ctx: autograd context; receives the saved negative log-probs
                and indices.
            logits: tensor whose last dimension is the class dimension.
            indices: ``torch.int64`` tensor with one target per row.

        Returns:
            Tensor shaped like ``indices`` with the dtype of ``logits``.

        Raises:
            AssertionError: if ``indices`` is not ``torch.int64``.
        """
        # make sure we can use triton
        assert (indices.dtype == torch.int64), "Indices are expected to be of type long."
        # The kernels address memory as ``base + row * N + col`` and
        # ``IDX + row``, so both inputs must be dense row-major buffers.
        logits = logits.contiguous()
        indices = indices.contiguous()
        # make kernel
        device, dtype = logits.device, logits.dtype
        n_cols = logits.shape[-1]
        # run the kernel
        result = torch.empty_like(indices, dtype=dtype, device=device)
        neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device)
        # one program per row
        grid = lambda opt: (logits.numel() // n_cols, )
        _forward[grid](logits, neg_logprobs, indices, result, n_cols)
        # save for backward
        ctx.save_for_backward(neg_logprobs, indices)
        return result

    @classmethod
    def backward(cls, ctx, dneg_logprobs):
        """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k]
        so we initialize the gradient as neg_logprobs, so we can just exponentiate
        to get p[k], which is most of what we need...  neg_logprobs will be
        modified in place to become the gradient we want
        """
        # load saved tensors
        neg_logprobs, indices = ctx.saved_tensors
        # The kernel reads the incoming gradient as ``DPROBS + row``.
        dneg_logprobs = dneg_logprobs.contiguous()
        # run the kernel
        # neg_logprobs will be modified in place to become our gradient:
        n_cols = neg_logprobs.shape[-1]
        grid = lambda opt: (neg_logprobs.numel() // n_cols, )
        _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols)
        # no gradient for the integer indices
        return neg_logprobs, None


cross_entropy = _cross_entropy.apply
|
||||
271
pkgs/triton/ops/flash_attention.py
Normal file
271
pkgs/triton/ops/flash_attention.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
This is a Triton implementation of the Flash Attention algorithm
|
||||
(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.common.build import is_corex
|
||||
|
||||
|
||||
# Fused attention forward kernel.
# program_id(0) selects a block of BLOCK_M query rows, program_id(1) the
# flattened (batch, head) index. All loads and stores are unmasked.
# NOTE(review): the per-head base offset is ``off_hz * stride_*h``, which
# presumably relies on the batch stride being H times the head stride -
# confirm against callers.
@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    L, M,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # K and V are addressed with their own strides so the initial tiles
    # agree with the per-iteration pointer increments at the end of the loop.
    # K is indexed as a (BLOCK_DMODEL, BLOCK_N) tile so tl.dot(q, k) needs no transpose.
    off_k = off_hz * stride_kh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # running row maximum (m), running normalizer (l) and output accumulator
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    # The range stops at the end of this query block; together with the
    # tl.where below this applies a causal mask.
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
# Attention backward pre-pass: each program handles BLOCK_M rows.
# It divides DO by the normalizer loaded from L, writes the result to NewDO,
# and writes the row-wise sum of Out * (scaled DO) to Delta.
# Out, DO and NewDO are addressed as ``row * D_HEAD + col``, i.e. as dense
# row-major buffers with D_HEAD columns.
@triton.jit
def _bwd_preprocess(
    Out, DO, L,
    NewDO, Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    denom = tl.load(L + off_m).to(tl.float32)
    # compute
    do = do / denom[:, None]
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
    tl.store(Delta + off_m, delta)
|
||||
|
||||
|
||||
# Fused attention backward kernel: one program per flattened (batch, head)
# index (program_id(0)). For each key/value block it accumulates dk and dv
# over the query blocks from that block onwards, and accumulates dq in place
# in DQ through load / add / store.
# Q, DO and DQ are addressed with Q's strides; K and DK with K's strides;
# V and DV with V's strides. All loads and stores are unmasked.
@triton.jit
def _bwd_kernel(
    Q, K, V, sm_scale, Out, DO,
    DQ, DK, DV,
    L, M,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    # offset pointers for batch/head
    Q += off_z * stride_qz + off_h * stride_qh
    K += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ += off_z * stride_qz + off_h * stride_qh
    DK += off_z * stride_kz + off_h * stride_kh
    DV += off_z * stride_vz + off_h * stride_vh
    for start_n in range(0, num_block):
        lo = start_n * BLOCK_M
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_m = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        m_ptrs = M + off_hz * N_CTX
        # initialize dv and dk
        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k = tl.load(k_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q = tl.load(q_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            # NOTE: `do` is pre-divided by `l`; no normalization here
            qk = tl.dot(q, tl.trans(k))
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
            m = tl.load(m_ptrs + offs_m_curr)
            p = tl.exp(qk * sm_scale - m[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            # dp starts at -Di so the delta term is already subtracted
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
            # compute dq
            dq = tl.load(dq_ptrs)
            dq += tl.dot(ds.to(Q.dtype.element_ty), k)
            tl.store(dq_ptrs, dq)
            # increment pointers
            dq_ptrs += BLOCK_M * stride_qm
            q_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)
        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        tl.store(dv_ptrs, dv)
        tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
    """Fused attention built from ``_fwd_kernel``, ``_bwd_preprocess`` and ``_bwd_kernel``."""

    @staticmethod
    def forward(ctx, q, k, v, sm_scale):
        """Run the forward kernel and save what ``backward`` needs.

        Args:
            ctx: autograd context.
            q, k, v: 4-D tensors; dimension 2 sizes the launch grid and the
                last dimension is the head size, which must be equal for
                all three and one of 16, 32, 64, 128.
            sm_scale: scale multiplied into the ``q @ k`` scores.

        Returns:
            Output tensor allocated with ``torch.empty_like(q)``.

        Raises:
            RuntimeError: on non-CoreX devices with compute capability
                major version below 8.
            AssertionError: if the head sizes differ or are unsupported.
        """
        # only support for Ampere now
        capability = torch.cuda.get_device_capability()
        if not is_corex():
            if capability[0] < 8:
                raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
            BLOCK = 128
        else:
            BLOCK = 64  # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q)
        # axis 0: blocks of query rows, axis 1: flattened (batch, head)
        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
        # per-row softmax statistics written by the kernel (normalizer and maximum)
        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        num_warps = 4

        _fwd_kernel[grid](
            q, k, v, sm_scale,
            L, m,
            o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=Lk, num_warps=num_warps,
            num_stages=2 if not is_corex() else 1,
        )

        ctx.save_for_backward(q, k, v, o, L, m)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        return o

    @staticmethod
    def backward(ctx, do):
        """Return ``(dq, dk, dv, None)`` for the upstream gradient ``do``.

        ``dq`` is allocated as zero-filled float32 because the kernel
        accumulates into it; ``dk`` and ``dv`` are written in full by the
        kernel. ``sm_scale`` receives no gradient.
        """
        BLOCK = 128 if not is_corex() else 64  # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases.
        num_warps = 16 if is_corex() and ctx.BLOCK_DMODEL > 64 else 8
        q, k, v, o, l, m = ctx.saved_tensors
        do = do.contiguous()
        dq = torch.zeros_like(q, dtype=torch.float32)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        do_scaled = torch.empty_like(do)
        delta = torch.empty_like(l)
        # pre-pass over every (row block, batch*head) pair of the forward grid
        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do, l,
            do_scaled, delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        # one program per (batch, head); ctx.grid[0] is passed as num_block
        _bwd_kernel[(ctx.grid[1],)](
            q, k, v, ctx.sm_scale,
            o, do_scaled,
            dq, dk, dv,
            l, m,
            delta,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
            ctx.grid[0],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
            num_stages=1,
        )
        return dq, dk, dv, None


attention = _attention.apply
|
||||
184
pkgs/triton/ops/matmul.py
Normal file
184
pkgs/triton/ops/matmul.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
def init_to_zero(name):
    """Return a hook that zeroes, in place, the argument called ``name``.

    The hook takes the mapping of kernel arguments, calls ``zero_()`` on
    ``nargs[name]`` and returns that call's result.
    """
    return lambda nargs: nargs[name].zero_()
|
||||
|
||||
def get_configs_io_bound():
    """Build the small-tile autotune configurations.

    Returns an empty list when ``torch`` has a ``corex`` attribute.
    Otherwise, for every combination of stage count, ``BLOCK_M``,
    ``BLOCK_K`` and ``BLOCK_N``, it emits one ``SPLIT_K == 1`` config plus
    one config per split factor in ``[2, 4, 8, 16]``. The split-K configs
    carry a pre-hook that zeroes ``C``.
    """
    configs = []
    if hasattr(torch, "corex"):
        return configs
    # TODO support block size 16 for MFMA dot op
    # (CoreX already returned above, so only the HIP check remains here.)
    block_ms = [16, 32] if torch.version.hip is None else [32, 64]
    for num_stages in [1, 2]:
        for block_m in block_ms:
            for block_k in [32, 64]:
                for block_n in [32, 64, 128, 256]:
                    num_warps = 2 if block_n <= 64 else 4
                    configs.append(
                        triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                                      num_stages=num_stages, num_warps=num_warps))
                    # split_k
                    for split_k in [2, 4, 8, 16]:
                        configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                                                     num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
    return configs
|
||||
|
||||
def get_configs_compute_bound():
    """Build the CoreX-only autotune configurations.

    When ``torch`` has a ``corex`` attribute, returns one ``SPLIT_K == 1``
    config for every combination of ``BLOCK_M``, ``BLOCK_N`` and
    ``BLOCK_K`` in ``[32, 64, 128, 256]`` and stage count in ``[1, 2]``.
    A config uses 16 warps if any block dimension is at least 128, else 8.
    Returns an empty list otherwise.
    """
    configs = []
    if hasattr(torch, "corex"):
        for block_m in [32, 64, 128, 256]:
            for block_n in [32, 64, 128, 256]:
                for block_k in [32, 64, 128, 256]:
                    for num_stages in [1, 2]:
                        num_warps = 16 if block_m >= 128 or block_n >= 128 or block_k >= 128 else 8
                        configs.append(
                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                                          num_stages=num_stages, num_warps=num_warps))
    return configs
|
||||
|
||||
def get_nv_config():
    """Return the fixed list of hand-picked autotune configurations.

    Returns an empty list when ``torch`` has a ``corex`` attribute.
    All configs use ``SPLIT_K == 1``.
    """
    configs = []
    if hasattr(torch, "corex"):
        return configs
    configs = [
        # basic configs for compute-bound matmuls
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        # good for int8
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
    ]
    return configs
|
||||
|
||||
# Tiled matmul kernel C = A @ B with optional split-K.
# program_id(0) enumerates output tiles, program_id(1) the split-K slice.
# EVEN_K is derived from K, BLOCK_K and SPLIT_K; when it holds, the loads
# along K are issued without a mask.
@triton.autotune(
    configs=get_nv_config() + get_configs_compute_bound() + get_configs_io_bound(),
    key=['M', 'N', 'K'],
    prune_configs_by={
        'early_config_prune': early_config_prune,
        'perf_model': estimate_matmul_time,
        'top_k': 10
    },
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
            stride_am, stride_ak,
            stride_bk, stride_bn,
            stride_cm, stride_cn,
            dot_out_dtype: tl.constexpr,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
            GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
            ):
    # matrix multiplication
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Row/column indices wrap modulo M / N so the loads stay in range;
    # out-of-range results are discarded by the store mask below.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            # mask off the part of the tile that lies beyond K
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype)
        # each split-K slice advances by BLOCK_K * SPLIT_K along K
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # partial sums from the split-K slices are combined atomically
        tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
class _matmul(torch.autograd.Function):
    """Autograd entry point for the 2-D matmul run by ``_kernel``.

    Only ``forward`` is defined in this class; no ``backward`` is provided
    here.
    """

    kernel = _kernel

    _locks = {}

    @staticmethod
    def _call(a, b, dot_out_dtype):
        """Allocate the output tensor and launch ``_kernel`` for ``a @ b``.

        Args:
            a: 2-D tensor of shape ``(M, K)``.
            b: 2-D tensor of shape ``(K, N)``.
            dot_out_dtype: ``None`` or a ``torch.dtype``. ``None`` selects
                ``tl.float32`` for float16/float32/bfloat16 inputs and
                ``tl.int32`` otherwise.

        Returns:
            A new ``(M, N)`` tensor with ``a``'s dtype on ``a``'s device.

        Raises:
            AssertionError: if the inner dimensions disagree, or
                ``dot_out_dtype`` is neither ``None`` nor a ``torch.dtype``.
        """
        device = a.device
        # handle non-contiguous inputs if necessary
        # (copy only when neither stride is 1, i.e. neither row- nor column-major)
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()

        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # allocates output
        c = torch.empty((M, N), device=device, dtype=a.dtype)
        if dot_out_dtype is None:
            if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                dot_out_dtype = tl.float32
            else:
                dot_out_dtype = tl.int32
        else:
            assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
            if dot_out_dtype == torch.float16:
                dot_out_dtype = tl.float16
            elif dot_out_dtype in [torch.float32, torch.bfloat16]:
                dot_out_dtype = tl.float32
            else:
                dot_out_dtype = tl.int32
        # launch kernel: axis 0 enumerates output tiles, axis 1 the split-K slices
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
        _kernel[grid](a, b, c, M, N, K,
                      a.stride(0), a.stride(1),
                      b.stride(0), b.stride(1),
                      c.stride(0), c.stride(1),
                      dot_out_dtype=dot_out_dtype,
                      GROUP_M=8)
        return c

    @staticmethod
    def forward(ctx, a, b, dot_out_dtype=None):
        """Return ``a @ b``; see ``_call``."""
        return _matmul._call(a, b, dot_out_dtype=dot_out_dtype)


matmul = _matmul.apply
|
||||
164
pkgs/triton/ops/matmul_perf_model.py
Normal file
164
pkgs/triton/ops/matmul_perf_model.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import heapq
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from triton.runtime import driver
|
||||
from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops
|
||||
|
||||
|
||||
def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS

    The peak from ``get_max_tensorcore_tflops`` is scaled by the fraction
    of sub-cores kept busy: ``num_ctas * min(num_warps, 4)`` warps against
    ``multiprocessor_count * 4`` sub-cores, capped at 1.
    '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device)
    return tflops
|
||||
|
||||
|
||||
def get_simd_tflops(backend, device, num_ctas, num_warps, dtype):
    ''' return compute throughput in TOPS

    The peak from ``get_max_simd_tflops`` is scaled by the fraction of
    sub-cores kept busy: ``num_ctas * min(num_warps, 4)`` warps against
    ``multiprocessor_count * 4`` sub-cores, capped at 1.
    '''
    total_warps = num_ctas * min(num_warps, 4)
    num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4  # on recent GPUs
    tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device)
    return tflops
|
||||
|
||||
|
||||
def get_tflops(backend, device, num_ctas, num_warps, dtype):
    """Return the estimated compute throughput for ``dtype`` on ``device``.

    Uses the SIMD estimate for ``torch.float32`` on devices whose compute
    capability major version is below 8, and the tensor-core estimate in
    every other case.
    """
    capability = torch.cuda.get_device_capability(device)
    if capability[0] < 8 and dtype == torch.float32:
        return get_simd_tflops(backend, device, num_ctas, num_warps, dtype)
    return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype)
|
||||
|
||||
|
||||
def estimate_matmul_time(
    # backend, device,
    num_warps, num_stages,
    A, B, C,
    M, N, K,
    BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
    debug=False, **kwargs
):
    ''' return estimated running time in ms
          = compute + loading + store

    NOTE(review): this docstring previously said ``max(compute, loading) +
    store``; the code sums all three terms, which is what is documented
    here. Confirm which model is intended.

    ``A`` supplies the dtype and element size; ``num_stages``, ``B``, ``C``
    and ``BLOCK_K`` are accepted but not read. With ``debug=True`` the
    breakdown is printed.
    '''
    backend = _triton.runtime.backend.CUDA
    device = torch.cuda.current_device()
    dtype = A.dtype
    dtsize = A.element_size()

    num_cta_m = triton.cdiv(M, BLOCK_M)
    num_cta_n = triton.cdiv(N, BLOCK_N)
    num_cta_k = SPLIT_K
    num_ctas = num_cta_m * num_cta_n * num_cta_k

    # If the input is smaller than the block size
    M, N = max(M, BLOCK_M), max(N, BLOCK_N)

    # time to compute
    total_ops = 2 * M * N * K / (1024 * 1024 * 1024)  # GOPS
    tput = get_tflops(backend, device, num_ctas, num_warps, dtype)
    compute_ms = total_ops / tput

    # time to load data
    num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"]
    # only reported in the debug output below
    active_cta_ratio = min(1, num_ctas / num_sm)
    active_cta_ratio_bw1 = min(1, num_ctas / 32)  # 32 active ctas are enough to saturate
    active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0)  # 32-108, remaining 5%
    dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05)  # in GB/s
    l2_bw = dram_bw * 4  # rough estimation (should be 4.7 for A100?)
    # assume 80% of (following) loads are in L2 cache
    load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1))
    load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1)
    load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1))
    load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1)
    # total
    total_dram = (load_a_dram + load_b_dram) / (1024 * 1024)  # MB
    total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024)
    # loading time in ms
    load_ms = total_dram / dram_bw + total_l2 / l2_bw

    # estimate storing time
    store_bw = dram_bw * 0.6  # :o
    store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024)  # MB
    if SPLIT_K == 1:
        store_ms = store_c_dram / store_bw
    else:
        reduce_bw = store_bw
        store_ms = store_c_dram / reduce_bw
        # c.zero_()
        zero_ms = M * N * 2 / (1024 * 1024) / store_bw
        store_ms += zero_ms

    total_time_ms = compute_ms + load_ms + store_ms
    if debug:
        print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, '
              f'loading time: {load_ms}ms, store time: {store_ms}ms, '
              f'Activate CTAs: {active_cta_ratio*100}%')
    return total_time_ms
|
||||
|
||||
|
||||
def early_config_prune(configs, named_args):
    """Filter autotune ``configs`` before benchmarking.

    Steps, in order:
      1. drop configs whose ``(BLOCK_M + BLOCK_N) * BLOCK_K * num_stages *
         element_size`` exceeds the device's ``max_shared_mem``;
      2. for dtypes other than float16/float32, keep only ``SPLIT_K == 1``;
      3. group the rest by ``(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K,
         num_warps)`` and pick stage counts per group, depending on the
         compute capability and on whether ``torch`` has a ``corex``
         attribute.

    Args:
        configs: iterable of config objects exposing ``kwargs``,
            ``num_warps`` and ``num_stages``.
        named_args: mapping of kernel arguments; only ``named_args['A']``
            is read (for element size and dtype).

    Returns:
        The list of retained configs.
    """
    device = torch.cuda.current_device()
    capability = torch.cuda.get_device_capability()
    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
    dtsize = named_args['A'].element_size()
    dtype = named_args['A'].dtype

    # 1. make sure we have enough smem
    # (the device limit does not depend on the config, so query it once)
    max_shared_memory = driver.utils.get_device_properties(device)["max_shared_mem"]
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
            kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages

        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory <= max_shared_memory:
            pruned_configs.append(config)
    configs = pruned_configs

    # Some dtypes do not allow atomic_add
    if dtype not in [torch.float16, torch.float32]:
        configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1]

    # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps)
    configs_map = {}
    for config in configs:
        kw = config.kwargs
        key = (kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps)
        configs_map.setdefault(key, []).append((config, config.num_stages))

    pruned_configs = []
    for k, v in configs_map.items():
        BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k
        if capability[0] >= 8:
            # compute cycles (only works for ampere GPUs)
            mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16)
            mma_cycles = mmas / min(4, num_warps) * 8

            ldgsts_latency = 300  # Does this matter?
            optimal_num_stages = ldgsts_latency / mma_cycles

            # nearest stages, prefer large #stages
            # (stage counts below the optimum are penalised by +10)
            nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages)
                                      if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages)

            for n in nearest:
                pruned_configs.append(n[0])
        else:  # Volta & Turing only supports num_stages <= 2
            if hasattr(torch, "corex"):
                # keep every config of the group with its own stage count
                pruned_configs.extend(config for config, _ in v)
            else:
                # keep one config per group and force two stages on it
                random_config = v[0][0]
                random_config.num_stages = 2
                pruned_configs.append(random_config)
    return pruned_configs
|
||||
Reference in New Issue
Block a user