Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com> Co-authored-by: PGFLMG <1106310035@qq.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
175 lines
7.4 KiB
Plaintext
175 lines
7.4 KiB
Plaintext
#pragma once
|
|
|
|
// Misc
|
|
#include "cute/tensor.hpp"
|
|
#include "cutlass/arch/arch.h"
|
|
#include "cutlass/arch/mma.h"
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/detail/blockwise_scale_layout.hpp"
|
|
#include "cutlass/epilogue/dispatch_policy.hpp"
|
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
|
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
|
#include "cutlass/layout/layout.h"
|
|
#include "cutlass/numeric_conversion.h"
|
|
#include "cutlass/numeric_size.h"
|
|
|
|
// Collective Builder
|
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
|
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
|
|
#include "cutlass/epilogue/thread/activation.h"
|
|
#include "cutlass/gemm/collective/collective_builder.hpp"
|
|
|
|
// Integration
|
|
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
|
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
|
|
|
namespace expert_specialization {
|
|
|
|
using namespace cute;
|
|
|
|
struct PerfConfigLowMH20 {
|
|
// Swap A/B
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using MmaTileShape = Shape<_128, _32, _128>;
|
|
using ClusterShape = Shape<_2, _1, _1>;
|
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
|
using ScaleConfig =
|
|
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
|
};
|
|
|
|
struct PerfConfigLowMHx00 {
|
|
// Swap A/B
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using MmaTileShape = Shape<_256, _32, _128>;
|
|
using ClusterShape = Shape<_2, _1, _1>;
|
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
|
using ScaleConfig =
|
|
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
|
};
|
|
|
|
struct PerfConfigMiddleMH20 {
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using MmaTileShape = Shape<_64, _128, _128>;
|
|
using ClusterShape = Shape<_1, _2, _1>;
|
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
|
using ScaleConfig =
|
|
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
|
};
|
|
|
|
struct PerfConfigMiddleMHx00 {
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using MmaTileShape = Shape<_256, _64, _128>;
|
|
using ClusterShape = Shape<_2, _1, _1>;
|
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
|
using ScaleConfig =
|
|
cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
|
};
|
|
|
|
struct PerfConfigHighMH20 {
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using MmaTileShape = Shape<_64, _128, _128>;
|
|
using ClusterShape = Shape<_2, _1, _1>;
|
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise;
|
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
|
using ScaleConfig =
|
|
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
|
};
|
|
|
|
struct PerfConfigHighMHx00 {
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using MmaTileShape = Shape<_128, _128, _128>;
|
|
using ClusterShape = Shape<_1, _2, _1>;
|
|
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise;
|
|
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
|
|
using ScaleConfig =
|
|
cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
|
|
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
|
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
|
};
|
|
|
|
template <typename OutType, typename LayoutD, typename PerfConfig>
|
|
struct ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits {
|
|
using ElementA = cutlass::float_e4m3_t;
|
|
using ElementB = cutlass::float_e4m3_t;
|
|
using ElementC = void;
|
|
using ElementD = OutType;
|
|
using ElementAccumulator = float;
|
|
using LayoutA = cutlass::layout::RowMajor;
|
|
using LayoutB = cutlass::layout::ColumnMajor;
|
|
using LayoutC = LayoutD;
|
|
using LayoutSFA = typename PerfConfig::LayoutSFA;
|
|
using LayoutSFB = typename PerfConfig::LayoutSFB;
|
|
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
|
|
|
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
|
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
|
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
|
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
|
|
|
using ArchTag = cutlass::arch::Sm90;
|
|
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
|
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
|
using CustomEVTIdentity = // acc
|
|
cutlass::epilogue::fusion::Sm90EVT<
|
|
cutlass::epilogue::fusion::
|
|
Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
|
|
cutlass::epilogue::fusion::Sm90AccFetch>;
|
|
|
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
|
ArchTag,
|
|
OperatorClass,
|
|
typename PerfConfig::MmaTileShape,
|
|
typename PerfConfig::ClusterShape,
|
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
|
ElementAccumulator,
|
|
ElementAccumulator,
|
|
ElementC, // Use void to avoid load Matrix C
|
|
LayoutC*,
|
|
AlignmentC,
|
|
ElementD,
|
|
LayoutD*,
|
|
AlignmentD,
|
|
typename PerfConfig::EpilogueSchedule,
|
|
CustomEVTIdentity>::CollectiveOp;
|
|
|
|
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
|
ArchTag,
|
|
OperatorClass,
|
|
ElementA,
|
|
cute::tuple<LayoutA*, typename PerfConfig::LayoutSFA*>,
|
|
AlignmentA,
|
|
ElementB,
|
|
cute::tuple<LayoutB*, typename PerfConfig::LayoutSFB*>,
|
|
AlignmentB,
|
|
ElementAccumulator,
|
|
typename PerfConfig::MmaTileShape,
|
|
typename PerfConfig::ClusterShape,
|
|
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
|
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
|
typename PerfConfig::KernelSchedule>::CollectiveOp;
|
|
|
|
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue, void>;
|
|
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
|
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
|
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
|
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
|
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
|
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
|
};
|
|
|
|
} // namespace expert_specialization
|