From 616a3e20df5faf73a7d1581c7016413f64d583e4 Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Fri, 19 Sep 2025 14:12:09 +0800 Subject: [PATCH] [sgl-kernel] Support moe_sum_reduce cuda kernel (#10321) Co-authored-by: luoyuan.luo Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> --- .../fused_moe_triton/benchmark_sum_scale.py | 35 +- sgl-kernel/CMakeLists.txt | 1 + sgl-kernel/csrc/common_extension.cc | 2 + sgl-kernel/csrc/moe/moe_sum_reduce.cu | 303 ++++++++++++++++++ sgl-kernel/include/sgl_kernel_ops.h | 2 + sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/moe.py | 12 + 7 files changed, 346 insertions(+), 10 deletions(-) create mode 100644 sgl-kernel/csrc/moe/moe_sum_reduce.cu diff --git a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py index 979d2bbd1..ec6b2f2f2 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_sum_scale.py @@ -1,6 +1,7 @@ import torch import triton import triton.language as tl +from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda from triton.testing import do_bench @@ -57,7 +58,7 @@ def _moe_sum_reduce_kernel( # _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py -def moe_sum_reduce( +def moe_sum_reduce_triton( input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float ): assert input.is_contiguous() @@ -117,9 +118,9 @@ def get_benchmark(): x_names=["num_tokens"], x_vals=num_tokens_range, line_arg="version", - line_vals=["baseline", "compiled", "triton"], - line_names=["Original", "TorchCompile", "TritonKernel"], - styles=[("blue", "-"), ("green", "-"), ("red", "-")], + line_vals=["baseline", "compiled", "triton", "cuda"], + line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"], + styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")], ylabel="us", plot_name="sum_scaled_performance", args={}, @@ -140,8 +141,10 @@ def get_benchmark(): compute_sum_scaled_baseline(x, out, scaling_factor) elif version == "compiled": compute_sum_scaled_compiled(x, out, scaling_factor) + elif version == "triton": + moe_sum_reduce_triton(x, out, scaling_factor) else: - moe_sum_reduce(x, out, scaling_factor) + moe_sum_reduce_cuda(x, out, scaling_factor) # Benchmark quantiles = [0.5, 0.2, 0.8] @@ -155,9 +158,15 @@ def get_benchmark(): lambda: compute_sum_scaled_compiled(x, out, scaling_factor), quantiles=quantiles, ) + elif version == "triton": + ms, min_ms, max_ms = do_bench( + lambda: moe_sum_reduce_triton(x, out, scaling_factor), + quantiles=quantiles, + ) else: ms, min_ms, max_ms = do_bench( - lambda: moe_sum_reduce(x, out, scaling_factor), quantiles=quantiles + lambda: moe_sum_reduce_cuda(x, out, scaling_factor), + quantiles=quantiles, ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms @@ -176,11 +185,16 @@ def verify_correctness(num_tokens=1024): compute_sum_scaled_compiled(x, out_compiled, scaling_factor) out_triton = torch.empty_like(out_baseline) - moe_sum_reduce(x, out_triton, scaling_factor) + moe_sum_reduce_triton(x, out_triton, scaling_factor) - if torch.allclose( - out_baseline, out_compiled, atol=1e-2, rtol=1e-2 - ) and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2): + out_cuda = torch.empty_like(out_baseline) + moe_sum_reduce_cuda(x, out_cuda, scaling_factor) + + if ( + torch.allclose(out_baseline, out_compiled, atol=1e-2, rtol=1e-2) + and torch.allclose(out_baseline, out_triton, atol=1e-2, rtol=1e-2) + and torch.allclose(out_baseline, out_cuda, atol=1e-2, rtol=1e-2) + ): print("✅ All implementations match") else: print("❌ Implementations differ") @@ -188,6 +202,7 @@ def verify_correctness(num_tokens=1024): f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}" ) print(f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}") + print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}") if __name__ == "__main__": diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 9b73c048a..5a5c3ef39 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -309,6 +309,7 @@ set(SOURCES "csrc/moe/marlin_moe_wna16/ops.cu" "csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_fused_gate.cu" + "csrc/moe/moe_sum_reduce.cu" "csrc/moe/moe_topk_softmax_kernels.cu" "csrc/moe/nvfp4_blockwise_moe.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 21f3763f6..4b99f7645 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -217,6 +217,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + m.def("moe_sum_reduce(Tensor input, Tensor output, float routed_scaling_factor) -> ()"); + m.impl("moe_sum_reduce", torch::kCUDA, &moe_sum_reduce); m.def( "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int " "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> " diff --git a/sgl-kernel/csrc/moe/moe_sum_reduce.cu b/sgl-kernel/csrc/moe/moe_sum_reduce.cu new file mode 100644 index 000000000..6e5454336 --- /dev/null +++ b/sgl-kernel/csrc/moe/moe_sum_reduce.cu @@ -0,0 +1,303 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "cutlass/array.h" +#include "utils.h" + +template +__device__ __forceinline__ float to_float(T x) { + return static_cast(x); +} + +template <> +__device__ __forceinline__ float to_float(half x) { + return __half2float(x); +} + +template +__device__ __forceinline__ T from_float(float x) { + return static_cast(x); +} + +template <> +__device__ __forceinline__ half from_float(float x) { + return __float2half_rn(x); +} + +template +__device__ __forceinline__ T ldg_cg(const T* p) { + return __ldg(p); +} + +union Pack16B { + uint4 v; + __nv_bfloat16 u16[8]; +}; + +template +__global__ void moe_sum_reduce_warp_per_token_vec_kernel( + const at::BFloat16* __restrict__ x, + at::BFloat16* __restrict__ y, + const int64_t token_num, + const int64_t hidden_dim, + const int64_t topk_num, + const int64_t stride_token, // in elements + const int64_t stride_topk, // in elements + const int64_t out_stride_token, // in elements + const float scale) { + constexpr int VEC = 16; + constexpr int PACKS = VEC / 8; + + const int warp_id = threadIdx.x / 32; + const int lane = threadIdx.x % 32; + const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (t >= token_num) return; + + const int64_t n_chunks = hidden_dim / VEC; + + for (int64_t chunk = (int64_t)blockIdx.x * 32 + lane; chunk < n_chunks; chunk += (int64_t)gridDim.x * 32) { + const int64_t d = chunk * VEC; + const int64_t base = t * stride_token + d; + + float acc[VEC]; +#pragma unroll + for (int i = 0; i < VEC; ++i) + acc[i] = 0.f; + +#pragma unroll + for (int k = 0; k < topk_num; ++k) { +#pragma unroll + for (int p = 0; p < PACKS; ++p) { + const int64_t offset = base + (int64_t)k * stride_topk + p * 8; + Pack16B pack = {ldg_cg(reinterpret_cast(x + offset))}; + +#pragma unroll + for (int i = 0; i < 8; ++i) { + acc[p * 8 + i] += __bfloat162float(pack.u16[i]); + } + } + } + +#pragma unroll + for (int i = 0; i < VEC; ++i) + acc[i] *= scale; + +#pragma unroll + for (int p = 0; p < PACKS; ++p) { + Pack16B outp; +#pragma unroll + for (int i = 0; i < 8; ++i) { + outp.u16[i] = __float2bfloat16_rn(acc[p * 8 + i]); + } + const int64_t dst = t * out_stride_token + d + p * 8; + *reinterpret_cast(y + dst) = outp.v; + } + } +} + +template +__global__ void moe_sum_reduce_kernel_warp_token_topk( + const scalar_t* __restrict__ x, + scalar_t* __restrict__ y, + const int64_t token_num, + const int64_t hidden_dim, + const int64_t stride_token, + const int64_t stride_topk, + const int64_t out_stride_token, + const float scale) { + const int warp_id = threadIdx.x / 32; + const int lane = threadIdx.x % 32; + const int64_t t = (int64_t)blockIdx.y * WARPS_PER_BLOCK + warp_id; + if (t >= token_num) return; + + for (int64_t d = (int64_t)blockIdx.x * 32 + lane; d < hidden_dim; d += (int64_t)gridDim.x * 32) { + float acc = 0.f; + const int64_t base = t * stride_token + d; + +#pragma unroll + for (int k = 0; k < TOPK; ++k) { + acc += to_float(ldg_cg(&x[base + (int64_t)k * stride_topk])); + } + acc *= scale; + y[t * out_stride_token + d] = from_float(acc); + } +} + +template +__global__ void moe_sum_reduce_kernel( + const scalar_t* __restrict__ x, + scalar_t* __restrict__ y, + const int64_t token_num, + const int64_t hidden_dim, + const int64_t stride_token, + const int64_t stride_topk, + const int64_t out_stride_token, + const float scale) { + for (int t = blockIdx.y; t < token_num; t += gridDim.y) { + for (int d = blockIdx.x * blockDim.x + threadIdx.x; d < hidden_dim; d += blockDim.x * gridDim.x) { + const int64_t base = t * stride_token + d; + float acc = 0.f; + +#pragma unroll + for (int k = 0; k < TOPK; ++k) { + acc += to_float(x[base + (int64_t)k * stride_topk]); + } + + acc *= scale; + y[t * out_stride_token + d] = from_float(acc); + } + } +} + +void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor) { + TORCH_CHECK(input.is_cuda(), "input must be CUDA tensor"); + TORCH_CHECK(output.is_cuda(), "output must be CUDA tensor"); + TORCH_CHECK(input.dim() == 3, "input must be a 3D tensor like [token_num, topk_num, hidden_dim]"); + TORCH_CHECK(output.dim() == 2, "output must be [token_num, hidden_dim]"); + TORCH_CHECK(input.size(0) == output.size(0), "token dim mismatch"); + TORCH_CHECK(input.size(2) == output.size(1), "hidden_dim mismatch"); + + TORCH_CHECK(input.is_contiguous(), "expect input to be contiguous"); + TORCH_CHECK(output.is_contiguous(), "expect output to be contiguous"); + + const int64_t token_num = input.size(0); + const int64_t topk_num = input.size(1); + const int64_t hidden_dim = input.size(2); + + const int64_t in_stride_token = input.stride(0); + const int64_t in_stride_topk = input.stride(1); + const int64_t out_stride_token = output.stride(0); + + const float scale = static_cast(routed_scaling_factor); + + auto stream = at::cuda::getCurrentCUDAStream(); + + const bool fast_bf16_vec_ok = (input.scalar_type() == at::kBFloat16) && (token_num > 256) && (hidden_dim % 8 == 0); + + // Fast path for bf16 vectorize + if (fast_bf16_vec_ok) { + constexpr int WARPS_PER_BLOCK = 8; + constexpr int THREADS = WARPS_PER_BLOCK * 32; + + const int64_t n_chunks = hidden_dim / 8; + int64_t grid_x = (n_chunks + 32 - 1) / 32; + if (grid_x > 65535) grid_x = 65535; + + int64_t grid_y = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK; + if (grid_y > 65535) grid_y = 65535; + + dim3 block(THREADS); + dim3 grid(static_cast(grid_x), static_cast(grid_y)); + + auto stream = at::cuda::getCurrentCUDAStream(); + + moe_sum_reduce_warp_per_token_vec_kernel<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + token_num, + hidden_dim, + topk_num, + in_stride_token, + in_stride_topk, + out_stride_token, + scale); + + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "moe_sum_reduce CUDA kernel launch failed"); + return; + } + + const bool per_token_use_one_warp = (token_num > 128); + + auto dispatch_topk = [&](auto&& launch_kernel) { + switch (topk_num) { + case 2: + launch_kernel(std::integral_constant{}); + break; + case 4: + launch_kernel(std::integral_constant{}); + break; + case 8: + launch_kernel(std::integral_constant{}); + break; + case 9: + launch_kernel(std::integral_constant{}); + break; + default: + launch_kernel(std::integral_constant{}); + break; + } + }; + + if (!per_token_use_one_warp) { + // ---------- small-token ---------- + const int block_size = 256; + int64_t grid_x = (hidden_dim + block_size - 1) / block_size; + grid_x = grid_x > 65535 ? 65535 : grid_x; + int64_t grid_y = token_num < 65535 ? token_num : 65535; + + dim3 block(block_size); + dim3 grid(static_cast(grid_x), static_cast(grid_y)); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_small_token", [&] { + using scalar_t_ = scalar_t; + + auto lauch_small_token_kernel = [&](auto topk_c) { + constexpr int TK = decltype(topk_c)::value; + + moe_sum_reduce_kernel<<>>( + input.data_ptr(), + output.data_ptr(), + token_num, + hidden_dim, + in_stride_token, + in_stride_topk, + out_stride_token, + scale); + }; + dispatch_topk(lauch_small_token_kernel); + }); + + } else { + // ---------- warp-token ---------- + constexpr int WARPS_PER_BLOCK = 4; + constexpr int THREADS = WARPS_PER_BLOCK * 32; + + int64_t gx = (hidden_dim + 32 - 1) / 32; + gx = gx > 65535 ? 65535 : gx; + + int64_t gy = (token_num + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK; + gy = gy > 65535 ? 65535 : gy; + + dim3 block(THREADS); + dim3 grid(static_cast(gx), static_cast(gy)); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kHalf, at::kBFloat16, input.scalar_type(), "moe_sum_reduce_cuda_large_token", [&] { + using scalar_t_ = scalar_t; + + auto launch_large_token_kernel = [&](auto topk_c) { + constexpr int TK = decltype(topk_c)::value; + + moe_sum_reduce_kernel_warp_token_topk<<>>( + input.data_ptr(), + output.data_ptr(), + token_num, + hidden_dim, + in_stride_token, + in_stride_topk, + out_stride_token, + scale); + }; + dispatch_topk(launch_large_token_kernel); + }); + } + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "CUDA kernel launch failed"); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 5829a72e4..fdaba4c93 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -293,6 +293,8 @@ void moe_align_block_size( void topk_softmax( torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& gating_output, bool renormalize); +void moe_sum_reduce(at::Tensor& input, at::Tensor& output, double routed_scaling_factor); + std::vector moe_fused_gate( at::Tensor& input, at::Tensor& bias, diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index e6f8c0dc6..49a97bccc 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -112,6 +112,7 @@ from sgl_kernel.moe import ( fp8_blockwise_scaled_grouped_mm, moe_align_block_size, moe_fused_gate, + moe_sum_reduce, prepare_moe_input, topk_softmax, ) diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 66fec9f2b..584722b32 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -36,6 +36,18 @@ def topk_softmax( ) +def moe_sum_reduce( + input_tensor, + output_tensor, + routed_scaling_factor=0, +): + torch.ops.sgl_kernel.moe_sum_reduce.default( + input_tensor, + output_tensor, + routed_scaling_factor, + ) + + def moe_fused_gate( input_tensor, bias,