From 8e9fb43d825508dfacd6b970f19337eb2f755a7a Mon Sep 17 00:00:00 2001
From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com>
Date: Sat, 5 Jul 2025 13:25:49 +0800
Subject: [PATCH] Optimize Hopper CUTLASS FP8 Blockwise Grouped GEMM Kernel in Small K Scenario (#7782)

---
 .../csrc/moe/fp8_blockwise_moe_kernel.cu | 124 ++++++++++++------
 1 file changed, 86 insertions(+), 38 deletions(-)

diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
index dc022bcc9..d252c29c2 100644
--- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
+++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
@@ -61,7 +61,12 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
   using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
 
-  using FusionOperation = cutlass::epilogue::fusion::LinearCombination;
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using CustomEVTIdentity = // acc
+      cutlass::epilogue::fusion::Sm90EVT<
+          cutlass::epilogue::fusion::
+              Sm90Compute,
+          cutlass::epilogue::fusion::Sm90AccFetch>;
 
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -78,7 +83,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
       LayoutC*,
       AlignmentC,
       typename ScheduleConfig::EpilogueSchedule,
-      FusionOperation>::CollectiveOp;
+      CustomEVTIdentity>::CollectiveOp;
 
   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
       ArchTag,
@@ -452,7 +457,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace) {
-  struct MmaConfig {
+  struct MmaConfig0 {
     using ElementA = cutlass::float_e4m3_t;
     using MmaTileShape = Shape<_64, _128, _128>;
     using ClusterShape = Shape<_2, _1, _1>;
@@ -464,40 +469,86 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
 
+  struct MmaConfig1 {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _128, _128>;
+    using ClusterShape = Shape<_1, _2, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
+
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
 
-  run_get_group_gemm_starts(
-      expert_offsets,
-      a_ptrs,
-      b_ptrs,
-      out_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      a,
-      b,
-      output,
-      scales_a,
-      scales_b,
-      layout_sfa,
-      layout_sfb,
-      problem_sizes,
-      problem_sizes_transpose);
-  launch_sm90_fp8_blockwise_scaled_group_mm(
-      out_ptrs,
-      a_ptrs,
-      b_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      stride_a,
-      stride_b,
-      stride_c,
-      layout_sfa,
-      layout_sfb,
-      problem_sizes,
-      expert_offsets,
-      workspace);
+  if (a.size(1) > 128) {
+    run_get_group_gemm_starts(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  } else {
+    // Small K
+    run_get_group_gemm_starts(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  }
 }
 
 /**
@@ -641,7 +692,7 @@ void fp8_blockwise_scaled_grouped_mm(
 #endif
 
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
-  if (sm_version == 90 && a.size(1) > 256) {
+  if (sm_version == 90) {
     if (output.scalar_type() == torch::kBFloat16) {
       sm90_fp8_blockwise_group_mm_dispatch_shape(
           output,
@@ -687,8 +738,5 @@ void fp8_blockwise_scaled_grouped_mm(
   }
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
-      can_implement,
-      "No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ",
-      sm_version,
-      a.size(1));
+      can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
 }