Add a CUDA kernel for fusing mapping and weighted sum for MoE. (#6916)

Co-authored-by: Elfie Guo <elfiegxf@gmail.com>
This commit is contained in:
Elfie Guo
2025-06-07 15:24:39 -07:00
committed by GitHub
parent 62fec60d81
commit 3e56f557fd
7 changed files with 146 additions and 12 deletions

12
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu Normal file → Executable file
View File

@@ -174,10 +174,10 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
// bool use_small_config = a[0].size(0) <= 128;
struct MmaConfig1 {
using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_128, _32, _128>;
-    using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+    using MmaTileShape = Shape<_256, _32, _128>;
+    using ClusterShape = Shape<_2, _1, _1>; // Layout type for SFB matrix operand
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using ScaleConfig =
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
@@ -214,7 +214,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
torch::Tensor scales_a_t = scales_a.t();
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
-  if (a.size(0) <= 512 && a.size(1) >= 2048) {
+  if (a.size(0) <= 2048 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
expert_offsets,
a_ptrs,
@@ -247,7 +247,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
expert_offsets,
workspace);
output = output_t.t();
-  } else if (a.size(0) > 512 && a.size(1) >= 2048) {
+  } else if (a.size(0) > 2048 && a.size(1) >= 2048) {
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
expert_offsets,
a_ptrs,