diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py index 262f1ae39..4d9868710 100755 --- a/python/sglang/srt/layers/moe/cutlass_moe.py +++ b/python/sglang/srt/layers/moe/cutlass_moe.py @@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8( rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k)) rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128))) - if not is_sm100_supported(): - rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets) - w1_scale = w1_scale.contiguous() - c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype) c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype) @@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8( silu_and_mul(c1, intermediate) intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128) - if not is_sm100_supported(): - a2_scale = per_group_transpose(a2_scale, expert_offsets) - w2_scale = w2_scale.contiguous() fp8_blockwise_scaled_grouped_mm( c2, diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py index 892cc4c87..4a67ab3b6 100755 --- a/python/sglang/test/test_cutlass_moe.py +++ b/python/sglang/test/test_cutlass_moe.py @@ -8,6 +8,15 @@ from transformers import AutoConfig from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts +from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig + + +# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim def get_model_config(tp_size: int): @@ -69,16 +78,11 @@ def run_test(tp_size, batch_size, model_config, check=False): # --- Input Data --- # Use bf16/fp16 for input activation based on model config - x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001 + x = 
torch.randn((batch_size, H), device="cuda", dtype=dtype) # --- Weights (Generate in higher precision, then convert to FP8) --- # Generate weights suitable for FP8 conversion (e.g., scaled appropriately) - w1_hp = ( - torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001 - ) - w2_hp = ( - torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001 - + 0.00001 - ) + w1_hp = torch.randn((E, I, H), device="cuda", dtype=torch.float32) + w2_hp = torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) w1 = to_fp8(w1_hp) w2 = to_fp8(w2_hp) @@ -149,13 +153,13 @@ def run_test(tp_size, batch_size, model_config, check=False): ) # Note: Triton expects non-transposed weights + moe_config = MoeRunnerConfig(inplace=False) triton_lambda = lambda: fused_experts( x, w1, w2, (topk_weights, topk_ids, "dummy"), - inplace=False, - activation="silu", # Assuming SiLU activation common in MoEs + moe_config, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, @@ -221,32 +225,19 @@ def run_test(tp_size, batch_size, model_config, check=False): w1, # Original shape w2, # Original shape (topk_weights, topk_ids, "dummy"), - inplace=False, # Important: Use False to get output tensor - activation="silu", + moe_config, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, block_shape=block_shape, ) - # Ensure outputs are same dtype for comparison - y_cutlass = y_cutlass.to(dtype) - y_triton = y_triton.to(dtype) - - abs_error = torch.abs(y_cutlass - y_triton) - rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2) - - max_abs_err = abs_error.max().item() - max_rel_err = rel_error.max().item() - - print("y_cutlass:", y_cutlass[:, :10]) - print("y_triton:", y_triton[:, :10]) - print(f"Max absolute error: {max_abs_err:.6f}") - print(f"Max relative error: {max_rel_err:.6f}") + diff = calc_diff(y_cutlass, y_triton) + print(f"Diff: {diff:.6f}") # Tolerance might need adjustment based on FP8 specifics and kernel differences # FP8 
comparisons often require higher tolerance than FP16/BF16 - assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}" + assert diff < 1e-4, f"Diff too high! {diff}" print("Correctness check passed.") @@ -264,7 +255,21 @@ if __name__ == "__main__": "--batch-sizes", type=int, nargs="+", - default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default + default=[ + 1, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + ], # Adjusted default help="List of batch sizes to test", ) parser.add_argument("--check", action="store_true", help="Enable check mode") diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 09ec8b00f..307734ca7 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -45,7 +45,7 @@ include(FetchContent) FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass - GIT_TAG 664c4f7b3ed1959414905025728eef5568209479 + GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 GIT_SHALLOW OFF ) FetchContent_Populate(repo-cutlass) diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu index aad3ce1fa..1a11ce2d7 100644 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -457,39 +457,40 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, const torch::Tensor& workspace) { - struct MmaConfig0 { + struct MmaConfigSmallM { + // Swap A/B + using ElementA = cutlass::float_e4m3_t; + using MmaTileShape = Shape<_128, _32, _128>; + using ClusterShape = Shape<_2, _1, _1>; + // TODO: Check Pingpong or Cooperative + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, 
cute::GMMA::Major::K, cute::GMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + }; + + struct MmaConfigH20LargeK { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; - + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; - struct MmaConfig1 { + struct MmaConfigHx00AndH20SmallK { using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; - using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; - - using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); - using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); - }; - - // [NOTE] default for H20 - struct MmaConfigH20_default { - using ElementA = cutlass::float_e4m3_t; - using MmaTileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_1, _2, _1>; - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; - + using ScaleConfig = + cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, 
cute::GMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; @@ -497,33 +498,34 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( int num_experts = (int)expert_offsets.size(0); torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int); + torch::Tensor output_t = output.t(); + torch::Tensor a_t = a.t(); + torch::Tensor b_t = b.transpose(1, 2); + torch::Tensor scales_a_t = scales_a.t(); + torch::Tensor scales_b_t = scales_b.transpose(1, 2); - const std::string H20_device_type_str = "NVIDIA H20"; - bool is_h20_device = isDeviceType(H20_device_type_str); + const std::string H20_device_type_str("NVIDIA H20"); + bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str; - if (is_h20_device) { - using execute_gemm_config = MmaConfigH20_default; - run_get_group_gemm_starts< - execute_gemm_config::LayoutSFA, - execute_gemm_config::LayoutSFB, - execute_gemm_config::ScaleConfig>( + if (a.size(0) <= 2048) { + run_get_group_gemm_starts( expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, - a, - b, - output, - scales_a, - scales_b, + b_t, + a_t, + output_t, + scales_b_t, + scales_a_t, layout_sfa, layout_sfb, problem_sizes, - problem_sizes_transpose); - - launch_sm90_fp8_blockwise_scaled_group_mm( + problem_sizes_transpose, + true); + launch_sm90_fp8_blockwise_scaled_group_mm( out_ptrs, a_ptrs, b_ptrs, @@ -534,13 +536,17 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( stride_c, layout_sfa, layout_sfb, - problem_sizes, + problem_sizes_transpose, expert_offsets, workspace); + output = output_t.t(); } else { - if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) { + if (is_h20_device && a.size(1) > 128) { // For H20 with K > 128, use Pingpong Schedule - 
run_get_group_gemm_starts( + run_get_group_gemm_starts< + MmaConfigH20LargeK::LayoutSFA, + MmaConfigH20LargeK::LayoutSFB, + MmaConfigH20LargeK::ScaleConfig>( expert_offsets, a_ptrs, b_ptrs, @@ -556,7 +562,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( layout_sfb, problem_sizes, problem_sizes_transpose); - launch_sm90_fp8_blockwise_scaled_group_mm( + launch_sm90_fp8_blockwise_scaled_group_mm( out_ptrs, a_ptrs, b_ptrs, @@ -572,7 +578,10 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( workspace); } else { // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule - run_get_group_gemm_starts( + run_get_group_gemm_starts< + MmaConfigHx00AndH20SmallK::LayoutSFA, + MmaConfigHx00AndH20SmallK::LayoutSFB, + MmaConfigHx00AndH20SmallK::ScaleConfig>( expert_offsets, a_ptrs, b_ptrs, @@ -588,7 +597,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( layout_sfb, problem_sizes, problem_sizes_transpose); - launch_sm90_fp8_blockwise_scaled_group_mm( + launch_sm90_fp8_blockwise_scaled_group_mm( out_ptrs, a_ptrs, b_ptrs, diff --git a/sgl-kernel/tests/test_fp8_blockwise_moe.py b/sgl-kernel/tests/test_fp8_blockwise_moe.py index decb3e2fc..6f2279393 100755 --- a/sgl-kernel/tests/test_fp8_blockwise_moe.py +++ b/sgl-kernel/tests/test_fp8_blockwise_moe.py @@ -5,10 +5,6 @@ import pytest import torch from sgl_kernel import fp8_blockwise_scaled_grouped_mm -from sglang.srt.layers.quantization.fp8_kernel import ( - per_token_group_quant_fp8_hopper_moe_mn_major, -) - def cdiv(a: int, b: int) -> int: return -(a // -b) @@ -106,24 +102,19 @@ def is_sm90_supported(device=None) -> bool: not (is_sm100_supported() or is_sm90_supported()), reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90", ) -@pytest.mark.parametrize("num_experts", [8, 16]) +@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128]) @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16]) -@pytest.mark.parametrize("use_custom_kernel", [True, False]) -def 
test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel): - cc = torch.cuda.get_device_capability(None)[0] - if cc == 10 and use_custom_kernel: - return +def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype): device = "cuda" - alignment = 16 - n_g = alignment * random.randint(1, 5) * 128 - k_g = alignment * random.randint(1, 5) * 128 + alignment = 128 + n_g = random.randint(1, 64) * 128 + k_g = random.randint(1, 64) * 128 expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32) problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32) layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32) layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32) - a_original_tensors = [] a_tensors = [] b_tensors = [] a_scales_tensors = [] @@ -131,7 +122,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern baseline_tensors = [] for g in range(num_experts): - m_g = alignment * random.randint(1, 64) + m_g = random.randint(1, 256) expert_offsets[g + 1] = expert_offsets[g] + m_g problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device) @@ -144,7 +135,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern b_g, b_scale = per_block_cast_to_fp8( b ) # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1) - a_original_tensors.append(a) a_tensors.append(a_g) b_tensors.append(b_g) a_scales_tensors.append(a_scale) @@ -152,9 +142,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern baseline = torch.mm(a, b) baseline_tensors.append(baseline) - a_original_stack = torch.empty( - (expert_offsets[-1], k_g), device=device, dtype=out_dtype - ) a_stack = torch.empty( (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn ) @@ -162,52 +149,28 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern (num_experts, n_g, k_g), device=device, 
dtype=torch.float8_e4m3fn ) a_scale_stack = torch.empty( - (expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32 + (expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32 ) b_scale_stack = torch.empty( - (num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32 + (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32 ) for g in range(num_experts): # Matrix A is Row-Major. - a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = ( - a_original_tensors[g] - ) - a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[ + a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[ g - ] # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1) + ] # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1) b_stack[g] = b_tensors[g].t() # b_stack[g] -- (N, K):(K, 1) - if cc == 9: - # For SM90, we need MN-Major scale factor - # a_scales_tensors[g] -- (M, k):(k, 1) - # a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1) - a_scale_stack[ - expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128) - ] = (a_scales_tensors[g].t().contiguous().view(-1)) - b_scale_stack[g] = b_scales_tensors[g] # b_scale_stack[g] -- (k, n):(n, 1) - elif cc == 10: - # For SM100, we need K-Major scale factor - # a_scales_tensors[g] -- (M, k):(k, 1) - a_scale_stack[ - expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128) - ] = a_scales_tensors[g].view(-1) - b_scale_stack[g] = b_scales_tensors[ - g - ] # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later - a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128) - b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major. 
- if cc == 10: - b_scale_stack = b_scale_stack.transpose(1, 2).contiguous() - if use_custom_kernel: - # Replace a_stack, a_scale_stack with custom kernel output - a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major( - a_original_stack, - expert_offsets[:-1], - problem_sizes, - 128, - expert_tokens_alignment=alignment, - ) + # We need K-Major scale factor + a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[ + g + ] + b_scale_stack[g] = b_scales_tensors[ + g + ].t() # b_scale_stack[g] -- (n, k):(k, 1); transposed below to a (k, n) view (no .contiguous()) + b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major. + b_scale_stack = b_scale_stack.transpose(1, 2) c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype) a_strides = torch.full( @@ -250,7 +213,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern diff = calc_diff(actual, baseline) assert diff < 0.001 print( - f"cc={cc}0 num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK" + f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK" )