[sgl-kernel][1/N]Support Expert Specialization Grouped GEMM (#11432)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: PGFLMG <1106310035@qq.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,174 @@
|
||||
#pragma once
|
||||
|
||||
// Misc
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/detail/blockwise_scale_layout.hpp"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/layout/layout.h"
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass/numeric_size.h"
|
||||
|
||||
// Collective Builder
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
|
||||
#include "cutlass/epilogue/thread/activation.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
// Integration
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
namespace expert_specialization {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct PerfConfigLowMH20 {
|
||||
// Swap A/B
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct PerfConfigLowMHx00 {
|
||||
// Swap A/B
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_256, _32, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct PerfConfigMiddleMH20 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct PerfConfigMiddleMHx00 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_256, _64, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct PerfConfigHighMH20 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
struct PerfConfigHighMHx00 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
};
|
||||
|
||||
template <typename OutType, typename LayoutD, typename PerfConfig>
|
||||
struct ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = void;
|
||||
using ElementD = OutType;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
using LayoutSFA = typename PerfConfig::LayoutSFA;
|
||||
using LayoutSFB = typename PerfConfig::LayoutSFB;
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using CustomEVTIdentity = // acc
|
||||
cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::
|
||||
Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
typename PerfConfig::MmaTileShape,
|
||||
typename PerfConfig::ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementAccumulator,
|
||||
ElementC, // Use void to avoid load Matrix C
|
||||
LayoutC*,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD*,
|
||||
AlignmentD,
|
||||
typename PerfConfig::EpilogueSchedule,
|
||||
CustomEVTIdentity>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA*, typename PerfConfig::LayoutSFA*>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB*, typename PerfConfig::LayoutSFB*>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
typename PerfConfig::MmaTileShape,
|
||||
typename PerfConfig::ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename PerfConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
};
|
||||
|
||||
} // namespace expert_specialization
|
||||
Reference in New Issue
Block a user