Add a CUDA kernel for fusing mapping and weighted sum for MoE. (#6916)
Co-authored-by: Elfie Guo <elfiegxf@gmail.com>
This commit is contained in:
12
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Normal file → Executable file
12
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
Normal file → Executable file
@@ -174,10 +174,10 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
// bool use_small_config = a[0].size(0) <= 128;
|
||||
struct MmaConfig1 {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using MmaTileShape = Shape<_128, _32, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using MmaTileShape = Shape<_256, _32, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>; // Layout type for SFB matrix operand
|
||||
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
|
||||
using ScaleConfig =
|
||||
cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
@@ -214,7 +214,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
torch::Tensor scales_a_t = scales_a.t();
|
||||
torch::Tensor scales_b_t = scales_b.transpose(1, 2);
|
||||
|
||||
if (a.size(0) <= 512 && a.size(1) >= 2048) {
|
||||
if (a.size(0) <= 2048 && a.size(1) >= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
@@ -247,7 +247,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
|
||||
expert_offsets,
|
||||
workspace);
|
||||
output = output_t.t();
|
||||
} else if (a.size(0) > 512 && a.size(1) >= 2048) {
|
||||
} else if (a.size(0) > 2048 && a.size(1) >= 2048) {
|
||||
run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
|
||||
expert_offsets,
|
||||
a_ptrs,
|
||||
|
||||
Reference in New Issue
Block a user