From 5fd311d33e625edf2c9b53b1681cf27c55a84013 Mon Sep 17 00:00:00 2001
From: kousakawang
Date: Fri, 22 Aug 2025 10:23:29 +0800
Subject: [PATCH] [code clean] add H20 cutlass groupGemm default config (#9333)

Co-authored-by: wanghanpei
---
 .../csrc/moe/fp8_blockwise_moe_kernel.cu | 50 ++++++-------------
 1 file changed, 15 insertions(+), 35 deletions(-)

diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
index d0cf45431..aad3ce1fa 100644
--- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
+++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
@@ -437,34 +437,6 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
   }
 }
 
-#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c
-
-#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c
-
-#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C)                                                  \
-  struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) {                                                       \
-    using ElementA = cutlass::float_e4m3_t;                                                            \
-    using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;                              \
-    using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;                              \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;                    \
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;                        \
-    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());                                       \
-    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());                                       \
-  };
-
-#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C)                                                     \
-  struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) {                                                          \
-    using ElementA = cutlass::float_e4m3_t;                                                               \
-    using MmaTileShape = Shape<cute::Int<M>, cute::Int<N>, cute::Int<K>>;                                 \
-    using ClusterShape = Shape<cute::Int<A>, cute::Int<B>, cute::Int<C>>;                                 \
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;                    \
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;                           \
-    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());                                          \
-    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());                                          \
-  };
-
 template <typename OutType>
 void sm90_fp8_blockwise_group_mm_dispatch_shape(
     torch::Tensor& output,
@@ -509,20 +481,28 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
 
-  // [NOTE] Tuned for H20
-  GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1)
+  // [NOTE] default for H20
+  struct MmaConfigH20_default {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_64, _128, _128>;
+    using ClusterShape = Shape<_1, _2, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
 
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
 
-  bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL");
   const std::string H20_device_type_str = "NVIDIA H20";
-  bool is_h20 = isDeviceType(H20_device_type_str);
+  bool is_h20_device = isDeviceType(H20_device_type_str);
 
-  if (is_h20 && tuning_H20_kernel) {
-    using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1;
+  if (is_h20_device) {
+    using execute_gemm_config = MmaConfigH20_default;
     run_get_group_gemm_starts<
         execute_gemm_config::LayoutSFA,
         execute_gemm_config::LayoutSFB,
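
Note: isDeviceType() and getBoolEnv() are pre-existing sgl-kernel helpers and are not touched by this patch. For reference, a minimal sketch of what a device-name check like isDeviceType("NVIDIA H20") can look like, assuming a substring match against the name reported by the CUDA runtime; the function name and matching rule below are illustrative, not the actual sgl-kernel implementation:

    // Illustrative sketch only -- the real isDeviceType lives elsewhere in
    // sgl-kernel. Assumption: the check is a substring match on the device
    // name filled in by cudaGetDeviceProperties (e.g. "NVIDIA H20").
    #include <cuda_runtime.h>
    #include <string>

    static bool isDeviceTypeSketch(const std::string& device_type) {
      int device = 0;
      cudaGetDevice(&device);                  // ordinal of the current device
      cudaDeviceProp prop;
      cudaGetDeviceProperties(&prop, device);  // prop.name holds the device name
      return std::string(prop.name).find(device_type) != std::string::npos;
    }

With the SGL_TUNE_DEVICE_KERNEL environment gate removed, the pingpong 64x128x128 tile / 1x2x1 cluster configuration above becomes the default whenever the detected device is an H20.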