From c268c11c710a8f29346628b532e6d82dba501151 Mon Sep 17 00:00:00 2001
From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com>
Date: Tue, 15 Jul 2025 15:02:44 +0800
Subject: [PATCH] [feat] Support fusion kernel for constructing quant input
 and scale factor for fp8_blockwise_scaled_grouped_mm (#8023)

---
 .../srt/layers/quantization/fp8_kernel.py  | 85 +++++++++++++++++++
 sgl-kernel/tests/test_fp8_blockwise_moe.py | 43 +++++++---
 2 files changed, 117 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index d73f5bbab..7d73c5bc2 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -1166,3 +1166,88 @@ def scaled_fp8_quant(
         )  # True for static
 
     return output, scale
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_M": block_m}, num_warps=num_warps)
+        for block_m in [16, 32, 64, 128]
+        for num_warps in [2, 4, 8]
+    ],
+    key=["K", "BLOCK_K", "M_ALIGNMENT"],
+)
+@triton.jit
+def _per_token_group_quant_fp8_hopper_moe_mn_major(
+    a,  # (M, K):(K, 1)
+    expert_offsets,  # (num_experts,)
+    problem_sizes,  # (num_experts, 3)
+    a_fp8,  # (M, K):(K, 1)
+    sfa,  # (M, k)
+    K: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_M: tl.constexpr,  # tune
+):
+    k_offset = tl.program_id(0)
+    expert_id = tl.program_id(1)
+
+    m = tl.load(problem_sizes + expert_id * 3)
+    current_expert_offset = tl.load(expert_offsets + expert_id).to(tl.int64)
+    tl.multiple_of(m, M_ALIGNMENT)
+    tl.multiple_of(current_expert_offset, M_ALIGNMENT)
+
+    coord_k = k_offset * BLOCK_K + tl.arange(0, BLOCK_K)
+    for i in tl.range(tl.cdiv(m, BLOCK_M)):
+        coord_m = i * BLOCK_M + tl.arange(0, BLOCK_M)
+        a_ptrs = a + current_expert_offset * K + coord_m[:, None] * K + coord_k[None, :]
+        a_mask = (coord_m < m)[:, None] & (coord_k < K)[None, :]
+
+        inp = tl.load(a_ptrs, mask=a_mask).to(tl.float32)  # [BLOCK_M, BLOCK_K]
+        inp_amax = tl.max(tl.abs(inp), axis=1)  # [BLOCK_M,]
+        inp_amax = tl.clamp(inp_amax, min=1e-4, max=float("inf"))
+        inp_fp8 = (inp * (448.0 / inp_amax[:, None])).to(tl.float8e4nv)
+
+        # Store fp8
+        a_fp8_ptrs = (
+            a_fp8 + current_expert_offset * K + coord_m[:, None] * K + coord_k[None, :]
+        )
+        tl.store(a_fp8_ptrs, inp_fp8, mask=a_mask)
+
+        # Store sfa
+        k = tl.cdiv(K, BLOCK_K)
+        sfa_ptrs = (
+            sfa + current_expert_offset * k + k_offset * m + coord_m
+        )  # MN-Major with sfa
+        tl.store(sfa_ptrs, inp_amax / 448.0, mask=coord_m < m)
+
+
+def per_token_group_quant_fp8_hopper_moe_mn_major(
+    A: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes: torch.Tensor,
+    group_size: int,
+    expert_tokens_alignment: int = 1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert A.dim() == 2
+    assert A.is_contiguous(), "`A` is not contiguous"
+    assert (
+        A.shape[-1] % group_size == 0
+    ), "the last dimension of `A` must be divisible by `group_size`"
+
+    a_q = torch.empty_like(A, device=A.device, dtype=fp8_dtype)
+    M, K = A.shape[0], A.shape[1]
+    k = K // group_size
+    sfa = torch.empty((M, k), device=A.device, dtype=torch.float32)
+    num_experts = problem_sizes.shape[0]
+    grid = (k, num_experts)
+    _per_token_group_quant_fp8_hopper_moe_mn_major[grid](
+        A,
+        expert_offsets,
+        problem_sizes,
+        a_q,
+        sfa,
+        K,
+        group_size,
+        expert_tokens_alignment,
+    )
+    return a_q, sfa
diff --git a/sgl-kernel/tests/test_fp8_blockwise_moe.py b/sgl-kernel/tests/test_fp8_blockwise_moe.py
index 06e5290a4..decb3e2fc 100755
--- a/sgl-kernel/tests/test_fp8_blockwise_moe.py
+++ b/sgl-kernel/tests/test_fp8_blockwise_moe.py
@@ -5,6 +5,10 @@ import pytest
 import torch
 from sgl_kernel import fp8_blockwise_scaled_grouped_mm
 
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8_hopper_moe_mn_major,
+)
+
 
 def cdiv(a: int, b: int) -> int:
     return -(a // -b)
@@ -104,8 +108,11 @@ def is_sm90_supported(device=None) -> bool:
 )
 @pytest.mark.parametrize("num_experts", [8, 16])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
-def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
+@pytest.mark.parametrize("use_custom_kernel", [True, False])
+def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
     cc = torch.cuda.get_device_capability(None)[0]
+    if cc == 10 and use_custom_kernel:
+        return
     device = "cuda"
     alignment = 16
     n_g = alignment * random.randint(1, 5) * 128
@@ -116,6 +123,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
     layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
 
+    a_original_tensors = []
     a_tensors = []
     b_tensors = []
     a_scales_tensors = []
@@ -136,6 +144,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         b_g, b_scale = per_block_cast_to_fp8(
             b
         )  # b_g -- (K, N):(N, 1), b_scale -- (k, n):(n, 1)
+        a_original_tensors.append(a)
         a_tensors.append(a_g)
         b_tensors.append(b_g)
         a_scales_tensors.append(a_scale)
@@ -143,22 +152,15 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         baseline = torch.mm(a, b)
         baseline_tensors.append(baseline)
 
-
+    a_original_stack = torch.empty(
+        (expert_offsets[-1], k_g), device=device, dtype=out_dtype
+    )
     a_stack = torch.empty(
         (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
     )
     b_stack = torch.empty(
         (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
     )
-
-    for g in range(num_experts):
-        # Matrix A is Row-Major.
-        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
-            g
-        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
-        b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)
-    b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
-
     a_scale_stack = torch.empty(
         (expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
    )
@@ -167,6 +169,14 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     )
 
     for g in range(num_experts):
+        # Matrix A is Row-Major.
+        a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
+            a_original_tensors[g]
+        )
+        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[
+            g
+        ]  # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
+        b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)
         if cc == 9:
             # For SM90, we need MN-Major scale factor
             # a_scales_tensors[g] -- (M, k):(k, 1)
@@ -185,9 +195,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
             g
         ]  # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
     a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
+    b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
     if cc == 10:
         b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()
 
+    if use_custom_kernel:
+        # Replace a_stack, a_scale_stack with custom kernel output
+        a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
+            a_original_stack,
+            expert_offsets[:-1],
+            problem_sizes,
+            128,
+            expert_tokens_alignment=alignment,
+        )
+
     c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
     a_strides = torch.full(
         (num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
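
For reference, the layout the fused kernel produces can be mirrored in plain
PyTorch. The sketch below is illustrative only and is not part of the patch:
ref_per_token_group_quant_mn_major is a hypothetical helper that assumes the
same constants as the Triton kernel above (fp8-e4m3 max of 448.0, per-group
amax clamped to 1e-4) and Python-int expert offsets.

    import torch

    def ref_per_token_group_quant_mn_major(A, expert_offsets, group_size=128):
        # A: (M, K) row-major stack of all experts' tokens.
        # expert_offsets: list [0, m0, m0 + m1, ...] of per-expert row starts.
        M, K = A.shape
        k = K // group_size
        a_q = torch.empty((M, K), dtype=torch.float8_e4m3fn, device=A.device)
        sfa = torch.empty(M * k, dtype=torch.float32, device=A.device)
        for g in range(len(expert_offsets) - 1):
            start, end = expert_offsets[g], expert_offsets[g + 1]
            m = end - start
            x = A[start:end].float().view(m, k, group_size)
            amax = x.abs().amax(dim=-1).clamp(min=1e-4)  # (m, k) per-group amax
            a_q[start:end] = (
                (x * (448.0 / amax.unsqueeze(-1))).view(m, K).to(torch.float8_e4m3fn)
            )
            # MN-major scale storage: within each expert's slice the (m, k)
            # scale block is laid out column-major, i.e. element (r, j) lands
            # at flat offset start * k + j * m + r, matching sfa_ptrs in the
            # Triton kernel above.
            sfa[start * k : end * k] = (amax / 448.0).t().reshape(-1)
        return a_q, sfa.view(M, k)  # (M, k) container; per-expert MN-major data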
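
Under those assumptions, the fused path in the test can be cross-checked
against the loop it replaces, for example:

    # Hypothetical check on a Hopper (SM90) GPU, after the test above has
    # built a_original_stack and run the custom kernel:
    # offsets = expert_offsets.tolist()
    # a_q_ref, sfa_ref = ref_per_token_group_quant_mn_major(a_original_stack, offsets)
    # torch.testing.assert_close(a_scale_stack.view(-1), sfa_ref.view(-1))

The design point of the patch: the old test assembled a_stack and
a_scale_stack with a per-expert Python loop plus host-side shuffles, while the
new Triton kernel fuses quantization and the MN-major scale rearrangement into
a single pass over the activations, launched on a 2D grid of (K groups,
experts) with BLOCK_M chosen by autotuning.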