From 5bb0accbcffa0e47ab7987062d3895c330aa1f00 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Tue, 29 Apr 2025 12:52:36 +0800 Subject: [PATCH] cutlass 3.9 supported to improve fp8_blockwise_gemm (#5820) --- sgl-kernel/CMakeLists.txt | 2 +- .../csrc/gemm/fp8_blockwise_gemm_kernel.cu | 43 ++++++++----------- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index bbe6246eb..375b35d1c 100755 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -43,7 +43,7 @@ include(FetchContent) FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass - GIT_TAG 5e497243f7ad13a2aa842143f9b10bbb23d98292 + GIT_TAG e94e888df3551224738bfa505787b515eae8352f GIT_SHALLOW OFF ) FetchContent_Populate(repo-cutlass) diff --git a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu index b0b63df11..3ed96d067 100644 --- a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu +++ b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu @@ -34,12 +34,7 @@ using namespace cute; -template < - typename SchedulerType, - typename OutType, - typename TileShape, - typename ClusterShape, - typename ScaleGranularity> +template void launch_sm90_fp8_blockwise_scaled_mm( torch::Tensor& out, const torch::Tensor& a, @@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm( using LayoutD = cutlass::layout::RowMajor; constexpr int AlignmentD = AlignmentC; - static constexpr int ScaleGranularityM = size<0>(ScaleGranularity{}); - static constexpr int ScaleGranularityN = size<1>(ScaleGranularity{}); + using ScaleTileShape = Shape<_1, _128, _128>; + using ScaleConfig = decltype(cutlass::detail::sm90_trivial_blockwise_scale_config(ScaleTileShape{})); + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; @@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm( using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT; - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm( ArchTag, OperatorClass, ElementA, - LayoutA, + cute::tuple, AlignmentA, ElementB, - LayoutB, + cute::tuple, AlignmentB, ElementAccumulator, TileShape, @@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm( StrideC stride_c; StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, 1)); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, b_s_ptr}; + LayoutSFA layout_sfa = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_sfb = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, stride_a, b_ptr, stride_b, 4, a_s_ptr, layout_sfa, b_s_ptr, layout_sfb}; typename GemmKernel::EpilogueArguments epilogue_args{{}, nullptr, stride_d, o_ptr, stride_d}; typename Gemm::Arguments args = { @@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape( const torch::Tensor& scales_b) { using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; - using ScaleGranularity = Shape<_1, _128, _128>; auto k = a.size(1); auto n = b.size(1); if (k > 3 * n) { - launch_sm90_fp8_blockwise_scaled_mm< - cutlass::gemm::StreamKScheduler, - OutType, - TileShape, - ClusterShape, - ScaleGranularity>(out, a, b, scales_a, scales_b); + launch_sm90_fp8_blockwise_scaled_mm( + out, a, b, scales_a, scales_b); } else { - launch_sm90_fp8_blockwise_scaled_mm< - cutlass::gemm::PersistentScheduler, - OutType, - TileShape, - ClusterShape, - ScaleGranularity>(out, a, b, scales_a, scales_b); + launch_sm90_fp8_blockwise_scaled_mm( + out, a, b, scales_a, scales_b); } }