Files
enginex-bi_series-vllm/pkgs/triton/ops/matmul.py
2025-08-05 19:02:46 +08:00

185 lines
8.2 KiB
Python

import torch
import triton
import triton.language as tl
from .matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
if hasattr(torch, "corex"):
return configs
for num_stages in [1, 2]:
# TODO support block size 16 for MFMA dot op
for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
def get_configs_compute_bound():
configs = []
if hasattr(torch, "corex"):
for block_m in [32, 64, 128, 256]:
for block_n in [32, 64, 128, 256]:
for block_k in [32, 64, 128, 256]:
for num_stages in [1, 2]:
num_warps = 16 if block_m >= 128 or block_n >=128 or block_k >= 128 else 8
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
return configs
def get_nv_config():
configs = []
if hasattr(torch, "corex"):
return configs
configs = [# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
]
return configs
@triton.autotune(
configs=get_nv_config() + get_configs_compute_bound() + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
dot_out_dtype: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b, out_dtype=dot_out_dtype)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
class _matmul(torch.autograd.Function):
kernel = _kernel
_locks = {}
@staticmethod
def _call(a, b, dot_out_dtype):
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=a.dtype)
if dot_out_dtype is None:
if a.dtype in [torch.float16, torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
else:
assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype"
if dot_out_dtype == torch.float16:
dot_out_dtype = tl.float16
elif dot_out_dtype in [torch.float32, torch.bfloat16]:
dot_out_dtype = tl.float32
else:
dot_out_dtype = tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
dot_out_dtype=dot_out_dtype,
GROUP_M=8)
return c
@staticmethod
def forward(ctx, a, b, dot_out_dtype=None):
return _matmul._call(a, b, dot_out_dtype=dot_out_dtype)
matmul = _matmul.apply