From 85ed8e0a5ec9ed47a5d1ac93ecc8bee2159468e0 Mon Sep 17 00:00:00 2001 From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com> Date: Sun, 7 Sep 2025 13:31:00 +0800 Subject: [PATCH] Optimize nvfp4 block scaled gemm kernel when M is small. (#10101) --- .../csrc/gemm/nvfp4_scaled_mm_kernels.cu | 144 ++++++++++++++---- 1 file changed, 114 insertions(+), 30 deletions(-) diff --git a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu index cc4804298..a103545dd 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu @@ -38,27 +38,74 @@ limitations under the License. using namespace cute; #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) -// Kernel Perf config +// Config(half_t/bfloat16_t) for M <= 128 template -struct KernelTraits { +struct KernelConfigM128 { + using OutputType = T; + using MmaTileShape = Shape<_128, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigM128::preferred_cluster(1, 4, 1); +template +const dim3 KernelConfigM128::fallback_cluster(1, 2, 1); + +// Config(half_t/bfloat16_t) for M <= 256 +template +struct KernelConfigM256 { + using OutputType = T; using MmaTileShape = Shape<_256, _256, _256>; using ClusterShape = Shape; - using EpilogueTile = Shape<_128, _64>; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; }; +template +const dim3 KernelConfigM256::preferred_cluster(2, 4, 1); +template +const dim3 KernelConfigM256::fallback_cluster(2, 1, 1); -template <> -struct KernelTraits { +// Default config(half_t/bfloat16_t) for M > 256 +template +struct KernelConfigDefault { + using OutputType = T; + using MmaTileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape; + using EpilogueTile = Shape<_128, _64>; // Avoid register spilling + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized2Sm; + using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +template +const dim3 KernelConfigDefault::preferred_cluster(4, 4, 1); +template +const dim3 KernelConfigDefault::fallback_cluster(2, 1, 1); + +struct KernelConfigFp32 { + using OutputType = float; using MmaTileShape = Shape<_128, _128, _256>; using ClusterShape = Shape; using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized1Sm; using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; }; +const dim3 KernelConfigFp32::preferred_cluster = dim3(1, 4, 1); +const dim3 KernelConfigFp32::fallback_cluster = dim3(1, 2, 1); -template +template struct Fp4GemmSm100 { + using Config = KernelConfig; // For generating args + using OutputType = typename KernelConfig::OutputType; // A matrix configuration using ElementA = cutlass::nv_float4_t; using LayoutATag = cutlass::layout::RowMajor; @@ -70,8 +117,8 @@ struct Fp4GemmSm100 { static constexpr int AlignmentB = 32; // C/D matrix configuration - using ElementD = T; - using ElementC = T; + using ElementD = OutputType; + using ElementC = OutputType; using LayoutCTag = cutlass::layout::RowMajor; using LayoutDTag = cutlass::layout::RowMajor; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; @@ -82,15 +129,15 @@ struct Fp4GemmSm100 { using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Kernel Perf config - using MmaTileShape = typename KernelTraits::MmaTileShape; - using ClusterShape = typename KernelTraits::ClusterShape; - using EpilogueTile = typename KernelTraits::EpilogueTile; - using EpilogueSchedule = typename KernelTraits::EpilogueSchedule; - using MainloopSchedule = typename KernelTraits::MainloopSchedule; + using MmaTileShape = typename KernelConfig::MmaTileShape; + using ClusterShape = typename KernelConfig::ClusterShape; + using EpilogueTile = typename KernelConfig::EpilogueTile; + using EpilogueSchedule = typename KernelConfig::EpilogueSchedule; + using MainloopSchedule = typename KernelConfig::MainloopSchedule; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, - cutlass::arch::OpClassTensorOp, + OperatorClass, MmaTileShape, ClusterShape, EpilogueTile, @@ -182,19 +229,15 @@ typename T::Gemm::Arguments args_from_options( layout_SFB}, { // Epilogue arguments {}, // epilogue.thread - static_cast(D.data_ptr()), + nullptr, stride_D, static_cast(D.data_ptr()), stride_D}}; auto& fusion_args = arguments.epilogue.thread; fusion_args.alpha_ptr = static_cast(alpha.data_ptr()); - if constexpr (std::is_same_v) { - arguments.hw_info.cluster_shape = dim3(1, 4, 1); - arguments.hw_info.cluster_shape_fallback = dim3(1, 1, 1); - } else { - arguments.hw_info.cluster_shape = dim3(4, 4, 1); - arguments.hw_info.cluster_shape_fallback = dim3(2, 1, 1); - } + using KernelConfig = typename T::Config; + arguments.hw_info.cluster_shape = KernelConfig::preferred_cluster; + arguments.hw_info.cluster_shape_fallback = KernelConfig::fallback_cluster; return arguments; } @@ -210,11 +253,10 @@ void runGemm( int64_t n, int64_t k, cudaStream_t stream) { - typename Fp4GemmSm100::Gemm gemm; + typename T::Gemm gemm; + auto arguments = args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); - auto arguments = args_from_options>(D, A, B, A_sf, B_sf, alpha, m, n, k); - - size_t workspace_size = Fp4GemmSm100::Gemm::get_workspace_size(arguments); + size_t workspace_size = T::Gemm::get_workspace_size(arguments); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); auto workspace = torch::empty(workspace_size, workspace_options); @@ -224,9 +266,51 @@ void runGemm( CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); } + +// Dispatch function to select appropriate config based on M +template +void cutlassFp4GemmDispatch( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + if (m <= 128) { + // m in [1, 128] + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (m <= 256) { + // m in (128, 256] + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + // m in (256, inf) + runGemm>>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + +// Dispatch function to select appropriate config based on M +template <> +void cutlassFp4GemmDispatch( + torch::Tensor& D, + torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, + int64_t m, + int64_t n, + int64_t k, + cudaStream_t stream) { + runGemm>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); +} + #else template -void runGemm( +void cutlassFp4GemmDispatch( at::Tensor& D, at::Tensor const& A, at::Tensor const& B, @@ -358,11 +442,11 @@ void cutlass_scaled_fp4_mm_sm100a( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); if (out_dtype == at::ScalarType::Half) { - runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else if (out_dtype == at::ScalarType::BFloat16) { - runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else if (out_dtype == at::ScalarType::Float) { - runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlassFp4GemmDispatch(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm"); }