[ROCm] Manually unroll _w8a8_block_fp8_matmul kernel on AMD GPU. (#3299)
This commit is contained in:
committed by
GitHub
parent
c7256ca836
commit
c2723a42a5
@@ -220,6 +220,132 @@ def _w8a8_block_fp8_matmul(
|
|||||||
tl.store(c_ptrs, c, mask=c_mask)
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _w8a8_block_fp8_matmul_unrolledx4(
|
||||||
|
# Pointers to inputs and output
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
# Shape for matmul
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
# Block size for block-wise quantization
|
||||||
|
group_n,
|
||||||
|
group_k,
|
||||||
|
# Stride for inputs and output
|
||||||
|
stride_am,
|
||||||
|
stride_ak,
|
||||||
|
stride_bk,
|
||||||
|
stride_bn,
|
||||||
|
stride_cm,
|
||||||
|
stride_cn,
|
||||||
|
stride_As_m,
|
||||||
|
stride_As_k,
|
||||||
|
stride_Bs_k,
|
||||||
|
stride_Bs_n,
|
||||||
|
# Meta-parameters
|
||||||
|
BLOCK_SIZE_M: tl.constexpr,
|
||||||
|
BLOCK_SIZE_N: tl.constexpr,
|
||||||
|
BLOCK_SIZE_K: tl.constexpr,
|
||||||
|
GROUP_SIZE_M: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""Triton-accelerated function used to perform linear operations (dot
|
||||||
|
product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
|
||||||
|
tensor `C`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pid = tl.program_id(axis=0)
|
||||||
|
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||||
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||||
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||||
|
group_id = pid // num_pid_in_group
|
||||||
|
first_pid_m = group_id * GROUP_SIZE_M
|
||||||
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||||
|
pid_m = first_pid_m + (pid % group_size_m)
|
||||||
|
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||||
|
|
||||||
|
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||||
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||||
|
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||||
|
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||||
|
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||||
|
|
||||||
|
As_ptrs = As + offs_am * stride_As_m
|
||||||
|
offs_bsn = offs_bn // group_n
|
||||||
|
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
|
||||||
|
|
||||||
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||||
|
# manually unroll to 4 iterations
|
||||||
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4):
|
||||||
|
# 1st iteration
|
||||||
|
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
|
||||||
|
k_start = k * BLOCK_SIZE_K
|
||||||
|
offs_ks = k_start // group_k
|
||||||
|
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||||
|
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||||
|
|
||||||
|
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||||
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
# 2nd iteration
|
||||||
|
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
|
||||||
|
k_start = k_start + BLOCK_SIZE_K
|
||||||
|
offs_ks = k_start // group_k
|
||||||
|
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||||
|
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||||
|
|
||||||
|
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||||
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
# 3rd iteration
|
||||||
|
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
|
||||||
|
k_start = k_start + BLOCK_SIZE_K
|
||||||
|
offs_ks = k_start // group_k
|
||||||
|
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||||
|
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||||
|
|
||||||
|
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||||
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
# 4th iteration
|
||||||
|
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||||
|
|
||||||
|
k_start = k_start + BLOCK_SIZE_K
|
||||||
|
offs_ks = k_start // group_k
|
||||||
|
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||||
|
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||||
|
|
||||||
|
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||||
|
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||||
|
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||||
|
|
||||||
|
if C.dtype.element_ty == tl.bfloat16:
|
||||||
|
c = accumulator.to(tl.bfloat16)
|
||||||
|
elif C.dtype.element_ty == tl.float16:
|
||||||
|
c = accumulator.to(tl.float16)
|
||||||
|
else:
|
||||||
|
c = accumulator.to(tl.float32)
|
||||||
|
|
||||||
|
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||||
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||||
|
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||||
|
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||||
|
tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def get_w8a8_block_fp8_configs(
|
def get_w8a8_block_fp8_configs(
|
||||||
N: int, K: int, block_n: int, block_k: int
|
N: int, K: int, block_n: int, block_k: int
|
||||||
@@ -324,7 +450,12 @@ def w8a8_block_fp8_matmul(
|
|||||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
_w8a8_block_fp8_matmul[grid](
|
# Use manually unrolledx4 kernel on AMD GPU.
|
||||||
|
kernel = (
|
||||||
|
_w8a8_block_fp8_matmul_unrolledx4 if is_hip_ == True else _w8a8_block_fp8_matmul
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel[grid](
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
|
|||||||
Reference in New Issue
Block a user