Files
sglang/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_traits.cuh
Qi Yuhang 9a30914e94 [sgl-kernel][1/N]Support Expert Specialization Grouped GEMM (#11432)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: PGFLMG <1106310035@qq.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
2025-10-12 20:19:21 -07:00

175 lines
7.4 KiB
Plaintext

#pragma once
// Misc
#include "cute/tensor.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/detail/blockwise_scale_layout.hpp"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_size.h"
// Collective Builder
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
// Integration
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
namespace expert_specialization {
using namespace cute;
struct PerfConfigLowMH20 {
// Swap A/B
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct PerfConfigLowMHx00 {
// Swap A/B
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_256, _32, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct PerfConfigMiddleMH20 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct PerfConfigMiddleMHx00 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_256, _64, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct PerfConfigHighMH20 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
struct PerfConfigHighMHx00 {
using ElementA = cutlass::float_e4m3_t;
using MmaTileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using ScaleConfig =
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
template <typename OutType, typename LayoutD, typename PerfConfig>
struct ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits {
using ElementA = cutlass::float_e4m3_t;
using ElementB = cutlass::float_e4m3_t;
using ElementC = void;
using ElementD = OutType;
using ElementAccumulator = float;
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = LayoutD;
using LayoutSFA = typename PerfConfig::LayoutSFA;
using LayoutSFB = typename PerfConfig::LayoutSFB;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using CustomEVTIdentity = // acc
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::
Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
cutlass::epilogue::fusion::Sm90AccFetch>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
typename PerfConfig::MmaTileShape,
typename PerfConfig::ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementAccumulator,
ElementC, // Use void to avoid load Matrix C
LayoutC*,
AlignmentC,
ElementD,
LayoutD*,
AlignmentD,
typename PerfConfig::EpilogueSchedule,
CustomEVTIdentity>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA*, typename PerfConfig::LayoutSFA*>,
AlignmentA,
ElementB,
cute::tuple<LayoutB*, typename PerfConfig::LayoutSFB*>,
AlignmentB,
ElementAccumulator,
typename PerfConfig::MmaTileShape,
typename PerfConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename PerfConfig::KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
};
} // namespace expert_specialization