From 9b876889b7d1bdb63734de5cafabc9424042bc14 Mon Sep 17 00:00:00 2001 From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com> Date: Tue, 16 Sep 2025 17:47:37 +0800 Subject: [PATCH] Update CUTLASS. Refine KernelSchedule for fp8 (grouped) gemm. (#10491) --- sgl-kernel/CMakeLists.txt | 2 +- .../gemm/fp8_blockwise_gemm_sm90_dispatch.cuh | 2 +- sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 5fd991e30..3a1c64605 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -46,7 +46,7 @@ include(FetchContent) FetchContent_Declare( repo-cutlass GIT_REPOSITORY https://github.com/NVIDIA/cutlass - GIT_TAG a49a78ffefc86a87160dfe0ccc3a3a2d1622c918 + GIT_TAG 57e3cfb47a2d9e0d46eb6335c3dc411498efa198 GIT_SHALLOW OFF ) FetchContent_Populate(repo-cutlass) diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh b/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh index 8deda43e5..05b70c4f2 100644 --- a/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh +++ b/sgl-kernel/csrc/cutlass_extensions/gemm/fp8_blockwise_gemm_sm90_dispatch.cuh @@ -72,7 +72,7 @@ struct cutlass_3x_gemm_fp8_blockwise { using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT; - using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu index b2e1fc83c..e6a2ccbb9 100644 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -463,7 +463,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( using MmaTileShape = Shape<_128, _32, _128>; using ClusterShape = Shape<_2, _1, _1>; // TODO: Check Pingpong or Cooperative - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; @@ -475,7 +475,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>; @@ -487,7 +487,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( using ElementA = cutlass::float_e4m3_t; using MmaTileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_1, _2, _1>; - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;