From 6e92da8fca18c746a0aa15c7bd95b47b6827befa Mon Sep 17 00:00:00 2001 From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:49:36 +0800 Subject: [PATCH] [Fix][Ready]Fix register spilling in cutlass nvfp4 gemm kernel on Blackwell (#8127) --- .../csrc/gemm/nvfp4_scaled_mm_kernels.cu | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu index 4fc4972dc..d1193ea44 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu @@ -40,27 +40,21 @@ using namespace cute; #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) // Kernel Perf config template -struct KernelTraits; +struct KernelTraits { + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; +}; template <> struct KernelTraits { using MmaTileShape = Shape<_128, _128, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using PerSmTileShape_MNK = Shape<_128, _128, _256>; -}; - -template <> -struct KernelTraits { - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape<_4, _4, _1>; - using PerSmTileShape_MNK = Shape<_128, _256, _256>; -}; - -template <> -struct KernelTraits { - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape<_4, _4, _1>; - using PerSmTileShape_MNK = Shape<_128, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; }; template @@ -90,23 +84,26 @@ struct Fp4GemmSm100 { // Kernel Perf config using MmaTileShape = typename KernelTraits::MmaTileShape; using ClusterShape = typename KernelTraits::ClusterShape; - using PerSmTileShape_MNK = typename KernelTraits::PerSmTileShape_MNK; + using EpilogueTile = typename KernelTraits::EpilogueTile; + using EpilogueSchedule = typename KernelTraits::EpilogueSchedule; + using MainloopSchedule = typename KernelTraits::MainloopSchedule; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, - OperatorClass, - PerSmTileShape_MNK, + cutlass::arch::OpClassTensorOp, + MmaTileShape, ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, + EpilogueTile, ElementAccumulator, ElementAccumulator, - ElementC, + void, LayoutCTag, AlignmentC, ElementD, LayoutDTag, AlignmentD, - cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, @@ -122,7 +119,7 @@ struct Fp4GemmSm100 { ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout( sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + MainloopSchedule>::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>; @@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options( stride_D}}; auto& fusion_args = arguments.epilogue.thread; fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); + if constexpr (std::is_same_v) { + arguments.hw_info.cluster_shape = dim3(1, 4, 1); + arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1); + } else { + arguments.hw_info.cluster_shape = dim3(4, 4, 1); + arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1); + } return arguments; }