[ROCm] Fix fp8 unrolledx4 matmul kernel. (#3325)
Co-authored-by: HAI <hixiao@gmail.com>
This commit is contained in:
Committed by GitHub. Parent: 2d9c319594. Commit: 32de54ed1a.
@@ -279,12 +279,21 @@ def _w8a8_block_fp8_matmul_unrolledx4(
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
# manually unroll to 4 iterations
|
||||
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4):
+ UNROLL_FACTOR = 4
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
|
||||
# 1st iteration
|
||||
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ a = tl.load(
+     a_ptrs,
+     mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
+     other=0.0,
+ )
+ b = tl.load(
+     b_ptrs,
+     mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K,
+     other=0.0,
+ )
|
||||
|
||||
- k_start = k * BLOCK_SIZE_K
+ k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||
@@ -294,8 +303,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# 2nd iteration
|
||||
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ a = tl.load(
+     a_ptrs,
+     mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
+     other=0.0,
+ )
+ b = tl.load(
+     b_ptrs,
+     mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K,
+     other=0.0,
+ )
|
||||
|
||||
k_start = k_start + BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
@@ -307,8 +324,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# 3rd iteration
|
||||
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ a = tl.load(
+     a_ptrs,
+     mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
+     other=0.0,
+ )
+ b = tl.load(
+     b_ptrs,
+     mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K,
+     other=0.0,
+ )
|
||||
|
||||
k_start = k_start + BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
@@ -320,8 +345,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
# 4th iteration
|
||||
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+ a = tl.load(
+     a_ptrs,
+     mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
+     other=0.0,
+ )
+ b = tl.load(
+     b_ptrs,
+     mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K,
+     other=0.0,
+ )
|
||||
|
||||
k_start = k_start + BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
|
||||
Reference in New Issue
Block a user