From 32de54ed1a7b14bf7f54a61f9bab8c618c224449 Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Wed, 5 Feb 2025 20:44:15 -0600 Subject: [PATCH] [ROCm] Fix fp8 unrolledx4 matmul kernel. (#3325) Co-authored-by: HAI --- .../srt/layers/quantization/fp8_kernel.py | 53 +++++++++++++++---- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index ddd614fdf..28c371cfe 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -279,12 +279,21 @@ def _w8a8_block_fp8_matmul_unrolledx4( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # manually unroll to 4 iterations - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4): + UNROLL_FACTOR = 4 + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)): # 1st iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, + other=0.0, + ) - k_start = k * BLOCK_SIZE_K + k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K offs_ks = k_start // group_k a_s = tl.load(As_ptrs + offs_ks * stride_As_k) b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) @@ -294,8 +303,16 @@ def _w8a8_block_fp8_matmul_unrolledx4( b_ptrs += BLOCK_SIZE_K * stride_bk # 2nd iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, + other=0.0, + ) k_start = k_start + BLOCK_SIZE_K offs_ks = k_start // group_k @@ -307,8 +324,16 @@ def _w8a8_block_fp8_matmul_unrolledx4( b_ptrs += BLOCK_SIZE_K * stride_bk # 3rd iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, + other=0.0, + ) k_start = k_start + BLOCK_SIZE_K offs_ks = k_start // group_k @@ -320,8 +345,16 @@ def _w8a8_block_fp8_matmul_unrolledx4( b_ptrs += BLOCK_SIZE_K * stride_bk # 4th iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, + other=0.0, + ) k_start = k_start + BLOCK_SIZE_K offs_ks = k_start // group_k