diff --git a/python/sglang/test/test_cutlass_moe.py b/python/sglang/test/test_cutlass_moe.py index 496e6d487..892cc4c87 100755 --- a/python/sglang/test/test_cutlass_moe.py +++ b/python/sglang/test/test_cutlass_moe.py @@ -153,9 +153,8 @@ def run_test(tp_size, batch_size, model_config, check=False): x, w1, w2, - topk_weights, - topk_ids, - inplace=False, # Use False for benchmarking to avoid side effects if run multiple times + (topk_weights, topk_ids, "dummy"), + inplace=False, activation="silu", # Assuming SiLU activation common in MoEs use_fp8_w8a8=True, w1_scale=w1_scale, @@ -221,8 +220,7 @@ def run_test(tp_size, batch_size, model_config, check=False): x, w1, # Original shape w2, # Original shape - topk_weights, - topk_ids, + (topk_weights, topk_ids, "dummy"), inplace=False, # Important: Use False to get output tensor activation="silu", use_fp8_w8a8=True, @@ -266,7 +264,7 @@ if __name__ == "__main__": "--batch-sizes", type=int, nargs="+", - default=[1, 4, 8, 16, 32, 64, 128, 256, 512], # Adjusted default + default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # Adjusted default help="List of batch sizes to test", ) parser.add_argument("--check", action="store_true", help="Enable check mode") diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu index 748dd2137..d0cf45431 100644 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -437,6 +437,34 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( } } +#define JOIN_STRUCT_PP_NAME(m, n, k, a, b, c) sm90_fp8_pp_config##_##m##_##n##_##k##_##a##_##b##_##c + +#define JOIN_STRUCT_CO_NAME(m, n, k, a, b, c) sm90_fp8_co_config##_##m##_##n##_##k##_##a##_##b##_##c + +#define GENERATE_SM90_FP8_PP_CONFIG(M, N, K, A, B, C) \ + struct JOIN_STRUCT_PP_NAME(M, N, K, A, B, C) { \ + using ElementA = cutlass::float_e4m3_t; \ + using MmaTileShape = Shape, cute::Int, cute::Int>; \ + using ClusterShape = Shape, cute::Int, cute::Int>; \ + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; \ + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; \ + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \ + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \ + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \ + }; + +#define GENERATE_SM90_FP8_CO_CONFIG(M, N, K, A, B, C) \ + struct JOIN_STRUCT_CO_NAME(M, N, K, A, B, C) { \ + using ElementA = cutlass::float_e4m3_t; \ + using MmaTileShape = Shape, cute::Int, cute::Int>; \ + using ClusterShape = Shape, cute::Int, cute::Int>; \ + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \ + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; \ + using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>; \ + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); \ + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); \ + }; + template void sm90_fp8_blockwise_group_mm_dispatch_shape( torch::Tensor& output, @@ -481,13 +509,24 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); }; + // [NOTE] Tuned for H20 + GENERATE_SM90_FP8_PP_CONFIG(64, 128, 128, 1, 2, 1) + int num_experts = (int)expert_offsets.size(0); torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int); - if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) { - // For H20 with K > 128, use Pingpong Schedule - run_get_group_gemm_starts( + bool tuning_H20_kernel = getBoolEnv("SGL_TUNE_DEVICE_KERNEL"); + + const std::string H20_device_type_str = "NVIDIA H20"; + bool is_h20 = isDeviceType(H20_device_type_str); + + if (is_h20 && tuning_H20_kernel) { + using execute_gemm_config = sm90_fp8_pp_config_64_128_128_1_2_1; + run_get_group_gemm_starts< + execute_gemm_config::LayoutSFA, + execute_gemm_config::LayoutSFB, + execute_gemm_config::ScaleConfig>( expert_offsets, a_ptrs, b_ptrs, @@ -503,7 +542,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( layout_sfb, problem_sizes, problem_sizes_transpose); - launch_sm90_fp8_blockwise_scaled_group_mm( + + launch_sm90_fp8_blockwise_scaled_group_mm( out_ptrs, a_ptrs, b_ptrs, @@ -518,37 +558,71 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( expert_offsets, workspace); } else { - // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule - run_get_group_gemm_starts( - expert_offsets, - a_ptrs, - b_ptrs, - out_ptrs, - a_scales_ptrs, - b_scales_ptrs, - a, - b, - output, - scales_a, - scales_b, - layout_sfa, - layout_sfb, - problem_sizes, - problem_sizes_transpose); - launch_sm90_fp8_blockwise_scaled_group_mm( - out_ptrs, - a_ptrs, - b_ptrs, - a_scales_ptrs, - b_scales_ptrs, - stride_a, - stride_b, - stride_c, - layout_sfa, - layout_sfb, - problem_sizes, - expert_offsets, - workspace); + if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) { + // For H20 with K > 128, use Pingpong Schedule + run_get_group_gemm_starts( + expert_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + output, + scales_a, + scales_b, + layout_sfa, + layout_sfb, + problem_sizes, + problem_sizes_transpose); + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets, + workspace); + } else { + // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule + run_get_group_gemm_starts( + expert_offsets, + a_ptrs, + b_ptrs, + out_ptrs, + a_scales_ptrs, + b_scales_ptrs, + a, + b, + output, + scales_a, + scales_b, + layout_sfa, + layout_sfb, + problem_sizes, + problem_sizes_transpose); + launch_sm90_fp8_blockwise_scaled_group_mm( + out_ptrs, + a_ptrs, + b_ptrs, + a_scales_ptrs, + b_scales_ptrs, + stride_a, + stride_b, + stride_c, + layout_sfa, + layout_sfb, + problem_sizes, + expert_offsets, + workspace); + } } } diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h index d7d0d5d1f..d78049a68 100644 --- a/sgl-kernel/include/utils.h +++ b/sgl-kernel/include/utils.h @@ -254,6 +254,25 @@ inline int getSMVersion() { return sm_major * 10 + sm_minor; } +inline bool isDeviceType(const std::string& device_type) { + int deviceCount; + CHECK_CUDA_SUCCESS(cudaGetDeviceCount(&deviceCount)); + + int device_id = -1; + if (deviceCount >= 1) { + CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id)); + } else { + return false; + } + + cudaDeviceProp prop; + CHECK_CUDA_SUCCESS(cudaGetDeviceProperties(&prop, device_id)); + if (device_type == std::string(prop.name)) { + return true; + } + return false; +} + inline bool getBoolEnv(char const* name) { char const* env = std::getenv(name); return env && env[0] == '1' && env[1] == '\0';