diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index ddd614fdf..28c371cfe 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -279,12 +279,21 @@ def _w8a8_block_fp8_matmul_unrolledx4( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # manually unroll to 4 iterations - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4): + UNROLL_FACTOR = 4 + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)): # 1st iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, + other=0.0, + ) - k_start = k * BLOCK_SIZE_K + k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K offs_ks = k_start // group_k a_s = tl.load(As_ptrs + offs_ks * stride_As_k) b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) @@ -294,8 +303,16 @@ def _w8a8_block_fp8_matmul_unrolledx4( b_ptrs += BLOCK_SIZE_K * stride_bk # 2nd iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, + other=0.0, + ) k_start = k_start + BLOCK_SIZE_K offs_ks = k_start // group_k @@ -307,8 +324,16 @@ def _w8a8_block_fp8_matmul_unrolledx4( b_ptrs += BLOCK_SIZE_K * stride_bk # 3rd iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, + other=0.0, + ) k_start = k_start + BLOCK_SIZE_K offs_ks = k_start // group_k @@ -320,8 +345,16 @@ def _w8a8_block_fp8_matmul_unrolledx4( b_ptrs += BLOCK_SIZE_K * stride_bk # 4th iteration - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, + other=0.0, + ) k_start = k_start + BLOCK_SIZE_K offs_ks = k_start // group_k