From 5f1a485d9e27341453d60389474d96c444c7d8bd Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 17 Feb 2025 18:01:21 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"[ROCm]=20Use=20`tl.range()`=20in=20bl?= =?UTF-8?q?ock=20GEMM=20kernels=20with=20`num=5Fstage=E2=80=A6=20(#3632)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../srt/layers/quantization/fp8_kernel.py | 107 +----------------- 1 file changed, 6 insertions(+), 101 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 3dc20467f..47f310a24 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -272,7 +272,6 @@ def _w8a8_block_fp8_matmul( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - num_stages: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B` with block-wise quantization, and store the result in output @@ -358,7 +357,6 @@ def _w8a8_block_fp8_matmul_unrolledx4( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - num_stages: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B` with block-wise quantization, and store the result in output @@ -388,9 +386,7 @@ def _w8a8_block_fp8_matmul_unrolledx4( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # manually unroll to 4 iterations UNROLL_FACTOR = 4 - for k in tl.range( - 0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR), num_stages=num_stages - ): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)): # 1st iteration a = tl.load( a_ptrs, @@ -489,92 +485,6 @@ def _w8a8_block_fp8_matmul_unrolledx4( tl.store(c_ptrs, c, mask=c_mask) -@triton.jit -def _w8a8_block_fp8_matmul_hip( - # Pointers to inputs and output - A, - B, - C, - As, - Bs, - # Shape for matmul - M, - N, - K, - # Block size for block-wise quantization - group_n, - group_k, - # Stride for inputs and output - stride_am, - stride_ak, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_As_m, - stride_As_k, - stride_Bs_k, - stride_Bs_n, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - num_stages: tl.constexpr, -): - """Triton-accelerated function used to perform linear operations (dot - product) on input tensors `A` and `B` with block-wise quantization, and store the result in output - tensor `C`. - """ - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - As_ptrs = As + offs_am * stride_As_m - offs_bsn = offs_bn // group_n - Bs_ptrs = Bs + offs_bsn * stride_Bs_n - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - - k_start = k * BLOCK_SIZE_K - offs_ks = k_start // group_k - a_s = tl.load(As_ptrs + offs_ks * stride_As_k) - b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) - - accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if C.dtype.element_ty == tl.bfloat16: - c = accumulator.to(tl.bfloat16) - elif C.dtype.element_ty == tl.float16: - c = accumulator.to(tl.float16) - else: - c = accumulator.to(tl.float32) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - @functools.lru_cache def get_w8a8_block_fp8_configs( N: int, K: int, block_n: int, block_k: int @@ -685,16 +595,11 @@ def w8a8_block_fp8_matmul( num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( N, config["BLOCK_SIZE_N"] ) - - kernel = _w8a8_block_fp8_matmul - - # On AMD GPU, use kernels where software pipelining with num_stages is - # explicitly specified in the hot loop. - if is_hip_ == True: - if num_workgroups <= get_device_core_count(): - kernel = _w8a8_block_fp8_matmul_unrolledx4 - else: - kernel = _w8a8_block_fp8_matmul_hip + kernel = ( + _w8a8_block_fp8_matmul_unrolledx4 + if (is_hip_ == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul + ) kernel[grid]( A,