From 3e56f557fdd107c678e3ea563c01baed0f6ef80b Mon Sep 17 00:00:00 2001
From: Elfie Guo <164945471+elfiegg@users.noreply.github.com>
Date: Sat, 7 Jun 2025 15:24:39 -0700
Subject: [PATCH] Add a CUDA kernel for fusing mapping and weighted sum for
 MoE. (#6916)

Co-authored-by: Elfie Guo
---
 python/sglang/srt/layers/moe/cutlass_moe.py   |  11 +-
 sgl-kernel/csrc/common_extension.cc           |   3 +-
 .../csrc/moe/fp8_blockwise_moe_kernel.cu      |  12 +-
 sgl-kernel/csrc/moe/prepare_moe_input.cu      | 114 ++++++++++++++++++
 sgl-kernel/include/sgl_kernel_ops.h           |   6 +
 sgl-kernel/python/sgl_kernel/__init__.py      |   1 +
 sgl-kernel/python/sgl_kernel/moe.py           |  11 ++
 7 files changed, 146 insertions(+), 12 deletions(-)
 mode change 100644 => 100755 sgl-kernel/csrc/common_extension.cc
 mode change 100644 => 100755 sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
 mode change 100644 => 100755 sgl-kernel/csrc/moe/prepare_moe_input.cu
 mode change 100644 => 100755 sgl-kernel/include/sgl_kernel_ops.h

diff --git a/python/sglang/srt/layers/moe/cutlass_moe.py b/python/sglang/srt/layers/moe/cutlass_moe.py
index 5b90c4fb4..00b7adf77 100755
--- a/python/sglang/srt/layers/moe/cutlass_moe.py
+++ b/python/sglang/srt/layers/moe/cutlass_moe.py
@@ -15,6 +15,7 @@ _is_cuda = is_cuda()
 if _is_cuda:
     import sgl_kernel
     from sgl_kernel import (
+        apply_shuffle_mul_sum,
         cutlass_fp4_group_mm,
         fp8_blockwise_scaled_grouped_mm,
         prepare_moe_input,
@@ -151,8 +152,8 @@ def cutlass_fused_experts_fp8(
         k,
     )
 
-    rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
-    rep_a1_scales = a1_scale[a_map]
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
 
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
@@ -206,9 +207,9 @@ def cutlass_fused_experts_fp8(
         expert_offsets[:-1],
         workspace,
     )
-    return (
-        c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
-    ).sum(dim=1)
+
+    result = torch.empty((m, k), device=device, dtype=out_dtype)
+    return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)
 
 
 FLOAT4_E2M1_MAX = 6.0
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
old mode 100644
new mode 100755
index 29f9a7605..1bc227197
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -195,7 +195,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
   m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
-
+  m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
+  m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
   /*
    * From csrc/speculative
    */
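The schema registered above takes the preallocated destination as `Tensor output` and the routing weights as an optional `Tensor? factors`, so the op writes its result in place. Functionally it fuses the three eager steps removed from `cutlass_fused_experts_fp8` (gather by `c_map`, scale by `topk_weights`, reduce over top-k) into one kernel launch, avoiding the materialized `(m, topk, k)` intermediate. An illustrative eager reference of those semantics (the helper name below is ours, not part of the patch):

    import torch

    def shuffle_mul_sum_reference(c2, c_map, topk_weights, m, topk, k, out_dtype):
        # Gather rows of c2 into (m, topk, k) order, weight each gathered row
        # by its top-k routing weight, then reduce over the top-k dimension.
        gathered = c2[c_map.long()].view(m, topk, k)
        weights = topk_weights.view(m, topk, 1).to(out_dtype)
        return (gathered * weights).sum(dim=1)  # shape (m, k)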
factors) -> ()"); + m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum); /* * From csrc/speculative */ diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu old mode 100644 new mode 100755 index b51849234..e3e170e47 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -174,10 +174,10 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( // bool use_small_config = a[0].size(0) <= 128; struct MmaConfig1 { using ElementA = cutlass::float_e4m3_t; - using MmaTileShape = Shape<_128, _32, _128>; - using ClusterShape = Shape<_1, _1, _1>; // Layout type for SFB matrix operand - using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100; - using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using MmaTileShape = Shape<_256, _32, _128>; + using ClusterShape = Shape<_2, _1, _1>; // Layout type for SFB matrix operand + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>; using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); @@ -214,7 +214,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( torch::Tensor scales_a_t = scales_a.t(); torch::Tensor scales_b_t = scales_b.transpose(1, 2); - if (a.size(0) <= 512 && a.size(1) >= 2048) { + if (a.size(0) <= 2048 && a.size(1) >= 2048) { run_get_group_gemm_starts( expert_offsets, a_ptrs, @@ -247,7 +247,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape( expert_offsets, workspace); output = output_t.t(); - } else if (a.size(0) > 512 && a.size(1) >= 2048) { + } else if (a.size(0) > 2048 && a.size(1) >= 2048) { run_get_group_gemm_starts( expert_offsets, a_ptrs, diff --git a/sgl-kernel/csrc/moe/prepare_moe_input.cu b/sgl-kernel/csrc/moe/prepare_moe_input.cu old mode 100644 new mode 100755 index 06237b56e..0eeec4c75 --- a/sgl-kernel/csrc/moe/prepare_moe_input.cu +++ b/sgl-kernel/csrc/moe/prepare_moe_input.cu @@ -252,3 +252,117 @@ void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2sr shuffle_rows_caller(input_tensor, dst2src_map, output_tensor); return; } + +template +__global__ void apply_shuffle_mul_sum_kernel( + const scalar_t* __restrict__ input_tensor, // [m * topk, row_stride] + scalar_t* __restrict__ output_tensor, // [m, row_stride] + const int32_t* __restrict__ permutation, // [m * topk] + int m, + int topk, + int row_stride, + const scalar_t* __restrict__ factors) // [m * topk] or nullptr +{ + int i = blockIdx.x; // [0, m * topk) + int d = threadIdx.x; // [0, row_stride) + + if (i >= m || d >= row_stride) return; + + scalar_t sum_val = 0.0; + + for (int j = 0; j < topk; ++j) { + int index_2d = i * topk + j; + int src_row = permutation[index_2d]; + if (src_row >= m) continue; + + scalar_t val = input_tensor[src_row * row_stride + d]; + + scalar_t factor = 1.0; + if (factors != nullptr) { + factor = factors[index_2d]; + } + + sum_val += factor * val; + } + + output_tensor[i * row_stride + d] = sum_val; +} + +void get_apply_shuffle_mul_sum_caller( + const torch::Tensor& input_tensor, // [m * topk, row_stride], bf16/f16 + torch::Tensor& output_tensor, // [m, row_stride], bf16/f16 + const torch::Tensor& permutation, // [m * topk], int32 + const std::optional& factors_opt) // optional [m * topk], bf16/f16 +{ + 
+
+void get_apply_shuffle_mul_sum_caller(
+    const torch::Tensor& input_tensor,   // [m * topk, row_stride], bf16/fp16
+    torch::Tensor& output_tensor,        // [m, row_stride], bf16/fp16
+    const torch::Tensor& permutation,    // [m * topk], int32
+    const std::optional<torch::Tensor>& factors_opt)  // optional [m * topk], bf16/fp16
+{
+  TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
+  TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
+  TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");
+
+  int m = output_tensor.size(0);
+  int topk = int(permutation.size(0) / m);
+  int row_stride = output_tensor.size(1);
+
+  TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");
+
+  dim3 block(std::min(256, row_stride));
+  dim3 grid(m);  // one block per output row; threads stride over columns
+  auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());
+
+  const int32_t* perm_ptr = permutation.data_ptr<int32_t>();
+
+  void* factors_ptr = nullptr;
+  if (factors_opt.has_value()) {
+    TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
+    TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
+    factors_ptr = factors_opt->data_ptr();
+  }
+
+  if (output_tensor.scalar_type() == at::ScalarType::Half) {
+    apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
+        input_tensor.data_ptr<at::Half>(),
+        output_tensor.data_ptr<at::Half>(),
+        perm_ptr,
+        m,
+        topk,
+        row_stride,
+        static_cast<const at::Half*>(factors_ptr));
+  } else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
+    apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
+        input_tensor.data_ptr<c10::BFloat16>(),
+        output_tensor.data_ptr<c10::BFloat16>(),
+        perm_ptr,
+        m,
+        topk,
+        row_stride,
+        static_cast<const c10::BFloat16*>(factors_ptr));
+  } else {
+    TORCH_CHECK(false, "Unsupported output dtype for apply_shuffle_mul_sum: ", output_tensor.scalar_type());
+  }
+}
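A quick way to validate the launcher above is to compare the fused op against the eager expression it replaces. A minimal test sketch, assuming a CUDA build of sgl-kernel that exposes `apply_shuffle_mul_sum`; the shapes and bf16 tolerances are illustrative:

    import torch
    from sgl_kernel import apply_shuffle_mul_sum

    m, topk, k = 8, 2, 512
    c2 = torch.randn(m * topk, k, device="cuda", dtype=torch.bfloat16)
    c_map = torch.randperm(m * topk, device="cuda").to(torch.int32)
    w = torch.rand(m, topk, device="cuda", dtype=torch.bfloat16)

    out = torch.empty(m, k, device="cuda", dtype=torch.bfloat16)
    apply_shuffle_mul_sum(c2, out, c_map, w)

    ref = (c2[c_map.long()].view(m, topk, k) * w.view(m, topk, 1)).sum(dim=1)
    torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)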
+
+/**
+ * @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
+ *
+ * This function performs the equivalent of the following PyTorch expression:
+ *
+ *   (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
+ *
+ * Specifically:
+ * - `input` is shuffled using the `permutation` tensor.
+ * - The shuffled tensor is reshaped and multiplied element-wise with `factors` (e.g., top-k weights).
+ * - The result is summed along dimension 1 (the top-k dimension) and stored in `output`.
+ *
+ * @param input       Input tensor of shape (m * topk, k), representing c2.
+ * @param output      Output tensor of shape (m, k), where the final reduced results are stored.
+ * @param permutation Index tensor (e.g., c_map) that maps positions in `input` to the shuffled layout.
+ * @param factors     Optional scaling factors (e.g., top-k weights), shape (m * topk) or (m, topk).
+ */
+void apply_shuffle_mul_sum(
+    const torch::Tensor& input,
+    torch::Tensor& output,
+    const torch::Tensor& permutation,
+    const std::optional<torch::Tensor>& factors) {
+  get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
old mode 100644
new mode 100755
index 586f7cafe..dbd969ea4
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -276,6 +276,12 @@ void ep_moe_post_reorder(
 
 void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);
 
+void apply_shuffle_mul_sum(
+    const torch::Tensor& input,
+    torch::Tensor& output,
+    const torch::Tensor& permutation,
+    const std::optional<torch::Tensor>& factors);
+
 void cutlass_fp4_group_mm(
     torch::Tensor& output,
     const torch::Tensor& a,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 9aef5a2b0..d9ce97417 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -48,6 +48,7 @@ from sgl_kernel.gemm import (
 )
 from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
 from sgl_kernel.moe import (
+    apply_shuffle_mul_sum,
     cutlass_fp4_group_mm,
     ep_moe_post_reorder,
     ep_moe_pre_reorder,
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
index b3497f517..176c979a9 100755
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -178,6 +178,18 @@ def prepare_moe_input(
     )
 
 
+def apply_shuffle_mul_sum(
+    input,
+    output,
+    permutation,
+    factors,
+):
+    torch.ops.sgl_kernel.apply_shuffle_mul_sum.default(
+        input, output, permutation, factors
+    )
+    # Return the destination tensor so call sites can use the result directly,
+    # e.g. `return apply_shuffle_mul_sum(c2, result, c_map, topk_weights)`.
+    return output
+
+
 def cutlass_fp4_group_mm(
     a_fp4,
     b_fp4,