From 930fe467bdc99e54c61dd63e69f5aa0a817b15e3 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Tue, 12 Aug 2025 21:21:55 -0700 Subject: [PATCH] Support Triton FP8 Gemm can handle hidden_dim not divisible by 16 (#9093) Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> --- .../srt/layers/quantization/fp8_kernel.py | 218 ++++++++++++++++++ .../srt/layers/quantization/fp8_utils.py | 26 ++- test/srt/quant/test_triton_scaled_mm.py | 94 ++++++++ test/srt/run_suite.py | 1 + 4 files changed, 332 insertions(+), 7 deletions(-) create mode 100644 test/srt/quant/test_triton_scaled_mm.py diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index c3be57649..e9df65a15 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -1415,3 +1415,221 @@ def per_group_transpose( a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8 ) return trans_a + + +def is_weak_contiguous(x: torch.Tensor): + strides = x.stride() + sizes = x.shape + is_not_transpose = strides[0] == 1 and (strides[1] >= max(1, sizes[0])) + is_transpose = strides[1] == 1 and (strides[0] >= max(1, sizes[1])) + return is_transpose or is_not_transpose + + +@triton.jit +def scaled_mm_kernel( + a_ptr, + b_ptr, + scale_a_ptr, + scale_b_ptr, + c_ptr, + bias_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + ACCUMULATOR_DTYPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_SCALE_A: tl.constexpr, + BLOCK_SIZE_SCALE_B: tl.constexpr, +): + pid = tl.program_id(axis=0) + + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = ACCUMULATOR_DTYPE + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + + # NOTE: Some tensor inputs are so large, they will cause int32 overflow + # so it is necessary to use tl.int64 for all the offsets, else SEGV will + # eventually occur. + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + masks_am = offsets_am < M + + offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + masks_bn = offsets_bn < N + + offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :] + offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :] + + # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create + # appropriate offsets and masks for each case. Same goes for + # BLOCK_SIZE_SCALE_B. + offsets_scale_am = ( + tl.arange(0, BLOCK_SIZE_SCALE_A) + + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M + ) + masks_scale_am = offsets_scale_am < M + + offsets_scale_bn = ( + tl.arange(0, BLOCK_SIZE_SCALE_B) + + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N + ) + masks_scale_bn = offsets_scale_bn < N + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + scale_a_ptrs = scale_a_ptr + offsets_scale_am + scale_b_ptrs = scale_b_ptr + offsets_scale_bn + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + + # Accumulate results. 
+ accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Apply scale at end. + masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None] + scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a) + # Need to broadcast to the appropriate size, if scale_a is already + # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes + # for scale_b below. + scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1)) + accumulator = scale_a * accumulator.to(tl.float32) + + masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :] + scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b) + scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1)) + accumulator = scale_b.T * accumulator.to(tl.float32) + + # Convert to output format. + c = accumulator.to(c_ptr.type.element_ty) + + # Add bias, it's already in output format, so add it after conversion. + if bias_ptr: + offsets_bias = offsets_bn + bias_ptrs = bias_ptr + offsets_bias + bias_mask = offsets_bias < N + bias = tl.load(bias_ptrs, bias_mask) + c += bias + + # Save output + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64) + offs_cm = offs_cm.to(tl.int64) + offs_cn = offs_cn.to(tl.int64) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + tl.store(c_ptrs, c, mask=c_mask) + + +# input - [M, K] +# weight - [K, N] +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +def triton_scaled_mm( + input: torch.Tensor, + weight: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: type[torch.dtype], + bias: Optional[torch.Tensor] = None, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, + use_heuristic=True, +) -> torch.Tensor: + M, K = input.shape + N = weight.shape[1] + + assert N > 0 and K > 0 and M > 0 + assert weight.shape[0] == K + assert input.dtype == weight.dtype + + scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a + scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b + + assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point() + assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M) + assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N) + assert out_dtype.is_floating_point + assert bias is None or bias.is_floating_point() + assert is_weak_contiguous(input) + assert is_weak_contiguous(weight) + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + result = torch.empty((M, N), dtype=out_dtype, device=input.device) + + has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1 + + if use_heuristic: + is_small_N = N < 8192 + next_power_of_2_M = max(32, triton.next_power_of_2(M)) + if next_power_of_2_M <= 32: + tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256) + elif next_power_of_2_M <= 64: + tile_shape = (64, 64, 256) + elif next_power_of_2_M <= 128: + tile_shape = (64, 128, 128) + else: + tile_shape = (128, 128, 128) + + block_size_m, block_size_n, block_size_k = tile_shape + + block_size_sa = 1 if has_scalar(scale_a) else block_size_m + block_size_sb = 1 if has_scalar(scale_b) else block_size_n + + 
accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32 + + # A = input, B = weight, C = result + # A = M x K, B = K x N, C = M x N + scaled_mm_kernel[grid]( + input, + weight, + scale_a, + scale_b, + result, + bias, + M, + N, + K, + input.stride(0), + input.stride(1), + weight.stride(0), + weight.stride(1), + result.stride(0), + result.stride(1), + accumulator_dtype, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_SCALE_A=block_size_sa, + BLOCK_SIZE_SCALE_B=block_size_sb, + ) + + return result.to(out_dtype) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 989056b37..f97a574c8 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.fp8_kernel import ( scaled_fp8_quant, sglang_per_token_quant_fp8, static_quant_fp8, + triton_scaled_mm, w8a8_block_fp8_matmul_deepgemm, w8a8_block_fp8_matmul_triton, ) @@ -586,14 +587,25 @@ def apply_fp8_linear( assert ( weight_scale.numel() == weight.shape[1] ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" - output = fp8_scaled_mm( - qinput, - weight, - x_scale, - weight_scale, - out_dtype=input.dtype, - bias=bias, + + cutlass_compatible_b = ( + weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0 ) + if not cutlass_compatible_b: + # Massage the input to be 2D + qinput = qinput.view(-1, qinput.shape[-1]) + output = triton_scaled_mm( + qinput, weight, x_scale, weight_scale, input.dtype, bias + ) + else: + output = fp8_scaled_mm( + qinput, + weight, + x_scale, + weight_scale, + out_dtype=input.dtype, + bias=bias, + ) return output.view(*output_shape) # torch.scaled_mm supports per tensor weights + activations only diff --git a/test/srt/quant/test_triton_scaled_mm.py b/test/srt/quant/test_triton_scaled_mm.py new file mode 100644 index 000000000..dafde83be --- /dev/null +++ b/test/srt/quant/test_triton_scaled_mm.py @@ -0,0 +1,94 @@ +import itertools +import unittest +from typing import Optional + +import torch +import torch.testing + +from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm +from sglang.test.test_utils import CustomTestCase + + +def torch_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Reference implementation using float32 for stability""" + out = torch.mm(a.to(torch.float32), b.to(torch.float32)) + out = scale_a.to(torch.float32) * out * scale_b.to(torch.float32).T + if bias is not None: + out = out + bias.to(torch.float32) + return out.to(out_dtype) + + +class TestScaledMM(CustomTestCase): + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("This test requires a CUDA device.") + torch.set_default_device("cuda") + + def _make_inputs(self, M, K, N, in_dtype): + if in_dtype == torch.int8: + a = torch.randint(-8, 8, (M, K), dtype=in_dtype, device="cuda") + b = torch.randint(-8, 8, (K, N), dtype=in_dtype, device="cuda") + else: # fp8 + a = torch.clamp( + 0.1 * torch.randn((M, K), dtype=torch.float16, device="cuda"), -0.3, 0.3 + ).to(in_dtype) + b = torch.clamp( + 0.1 * torch.randn((K, N), dtype=torch.float16, device="cuda"), -0.3, 0.3 + ).to(in_dtype) + return a, b + + def test_basic_cases(self): + """Test core functionality with reduced precision requirements""" + 
test_configs = [ + (32, 32, 32, torch.int8, torch.float16, False), + (64, 64, 64, torch.int8, torch.float16, True), + ] + + try: + torch.tensor([1.0], dtype=torch.float8_e4m3fn, device="cuda") + test_configs.append((32, 32, 32, torch.float8_e4m3fn, torch.float16, False)) + except: + print("FP8 not supported, skipping") + + for M, K, N, in_dtype, out_dtype, with_bias in test_configs: + with self.subTest(M=M, K=K, N=N, dtype=in_dtype, bias=with_bias): + print(f"Currently testing with in_dtype: {in_dtype}") + torch.manual_seed(42) + + input, weight = self._make_inputs(M, K, N, in_dtype) + scale_a = 0.1 + 0.05 * torch.rand( + (M, 1), dtype=torch.float32, device="cuda" + ) + scale_b = 0.1 + 0.05 * torch.rand( + (N, 1), dtype=torch.float32, device="cuda" + ) + bias = ( + 0.01 * torch.randn((M, N), dtype=out_dtype, device="cuda") + if with_bias + else None + ) + + triton_out = triton_scaled_mm( + input, weight, scale_a, scale_b, out_dtype, bias + ) + ref_out = torch_scaled_mm( + input, weight, scale_a, scale_b, out_dtype, bias + ) + + # Use relaxed tolerances + rtol = 0.15 if in_dtype == torch.int8 else 0.25 + atol = 0.1 if in_dtype == torch.int8 else 0.15 + + torch.testing.assert_close(triton_out, ref_out, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1bbdb7f65..b494cd9a7 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -57,6 +57,7 @@ suites = { TestFile("quant/test_block_int8.py", 22), TestFile("quant/test_fp8_kernel.py", 8), TestFile("quant/test_int8_kernel.py", 8), + TestFile("quant/test_triton_scaled_mm.py", 8), TestFile("quant/test_w8a8_quantization.py", 46), TestFile("rl/test_update_weights_from_disk.py", 114), TestFile("rl/test_update_weights_from_tensor.py", 48),
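For reference, below is a minimal, self-contained sketch (not part of the patch) of calling the new triton_scaled_mm directly with a K dimension that is not a multiple of 16 -- the shape class the CUTLASS fp8_scaled_mm path rejects and that apply_fp8_linear now routes to this Triton kernel. It mirrors how the test above builds its FP8 operands. It assumes a CUDA device with FP8 support, a PyTorch build exposing torch.float8_e4m3fn, and an sglang checkout containing this patch; the shapes and scale magnitudes are arbitrary.

# Illustrative usage sketch (not part of the patch).
import torch

from sglang.srt.layers.quantization.fp8_kernel import triton_scaled_mm

M, K, N = 16, 72, 128  # K % 16 != 0, so the CUTLASS path is not usable here

# Fake FP8 operands: clamp small fp16 values and cast to float8_e4m3fn,
# the same construction the unit test uses.
a = torch.clamp(
    0.1 * torch.randn(M, K, dtype=torch.float16, device="cuda"), -0.3, 0.3
).to(torch.float8_e4m3fn)
b = torch.clamp(
    0.1 * torch.randn(K, N, dtype=torch.float16, device="cuda"), -0.3, 0.3
).to(torch.float8_e4m3fn)

# Per-token activation scales (M, 1) and per-channel weight scales (N, 1);
# a (1, 1) scalar scale is also accepted by the kernel.
scale_a = 0.1 + 0.05 * torch.rand(M, 1, dtype=torch.float32, device="cuda")
scale_b = 0.1 + 0.05 * torch.rand(N, 1, dtype=torch.float32, device="cuda")

out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
print(out.shape)  # torch.Size([16, 128])

int8 inputs work the same way (the wrapper then selects an int32 accumulator); bias, if given, is expected to be broadcastable along the N output columns, matching the kernel's per-column bias load.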