From 03caefeb5117b0e3c29468216b82eae50a45e3b6 Mon Sep 17 00:00:00 2001
From: "Wen-Heng (Jack) Chung"
Date: Sun, 16 Feb 2025 03:40:38 -0600
Subject: [PATCH] [ROCm] Use `tl.range()` in block GEMM kernels with `num_stages` set by host. (#3535)

Co-authored-by: HAI
---
 .../srt/layers/quantization/fp8_kernel.py | 107 +++++++++++++++++-
 1 file changed, 101 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 47f310a24..3dc20467f 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -272,6 +272,7 @@ def _w8a8_block_fp8_matmul(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
     """Triton-accelerated function used to perform linear operations (dot
     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
@@ -357,6 +358,7 @@ def _w8a8_block_fp8_matmul_unrolledx4(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    num_stages: tl.constexpr,
 ):
     """Triton-accelerated function used to perform linear operations (dot
     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
@@ -386,7 +388,9 @@ def _w8a8_block_fp8_matmul_unrolledx4(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     # manually unroll to 4 iterations
     UNROLL_FACTOR = 4
-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
+    for k in tl.range(
+        0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR), num_stages=num_stages
+    ):
         # 1st iteration
         a = tl.load(
             a_ptrs,
@@ -485,6 +489,92 @@ def _w8a8_block_fp8_matmul_unrolledx4(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@triton.jit
+def _w8a8_block_fp8_matmul_hip(
+    # Pointers to inputs and output
+    A,
+    B,
+    C,
+    As,
+    Bs,
+    # Shape for matmul
+    M,
+    N,
+    K,
+    # Block size for block-wise quantization
+    group_n,
+    group_k,
+    # Stride for inputs and output
+    stride_am,
+    stride_ak,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_As_m,
+    stride_As_k,
+    stride_Bs_k,
+    stride_Bs_n,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    num_stages: tl.constexpr,
+):
+    """Triton-accelerated function used to perform linear operations (dot
+    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+    tensor `C`.
+    """
+
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + (pid % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+    As_ptrs = As + offs_am * stride_As_m
+    offs_bsn = offs_bn // group_n
+    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
+        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        k_start = k * BLOCK_SIZE_K
+        offs_ks = k_start // group_k
+        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if C.dtype.element_ty == tl.bfloat16:
+        c = accumulator.to(tl.bfloat16)
+    elif C.dtype.element_ty == tl.float16:
+        c = accumulator.to(tl.float16)
+    else:
+        c = accumulator.to(tl.float32)
+
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
 @functools.lru_cache
 def get_w8a8_block_fp8_configs(
     N: int, K: int, block_n: int, block_k: int
@@ -595,11 +685,16 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
+
+    kernel = _w8a8_block_fp8_matmul
+
+    # On AMD GPU, use kernels where software pipelining with num_stages is
+    # explicitly specified in the hot loop.
+    if is_hip_ == True:
+        if num_workgroups <= get_device_core_count():
+            kernel = _w8a8_block_fp8_matmul_unrolledx4
+        else:
+            kernel = _w8a8_block_fp8_matmul_hip
 
     kernel[grid](
         A,