From 43baba649e43d198eea73f58cd50882530763ebc Mon Sep 17 00:00:00 2001 From: Yuan Luo Date: Thu, 5 Jun 2025 15:33:47 +0800 Subject: [PATCH] [EP] Add cuda kernel for moe_ep_post_reorder (#6837) Co-authored-by: luoyuan.luo --- .../benchmark/bench_moe_ep_post_reorder.py | 92 ++++++++++ sgl-kernel/csrc/common_extension.cc | 8 +- sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu | 85 ++++++++- sgl-kernel/include/sgl_kernel_ops.h | 10 ++ sgl-kernel/python/sgl_kernel/__init__.py | 1 + sgl-kernel/python/sgl_kernel/moe.py | 22 +++ .../tests/test_ep_moe_post_reorder_kernel.py | 163 ++++++++++++++++++ 7 files changed, 377 insertions(+), 4 deletions(-) create mode 100644 sgl-kernel/benchmark/bench_moe_ep_post_reorder.py create mode 100644 sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py diff --git a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py new file mode 100644 index 000000000..701fb8c5b --- /dev/null +++ b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py @@ -0,0 +1,92 @@ +import torch +import triton +from sgl_kernel import ep_moe_post_reorder + +from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel + +batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096] +configs = [(bs,) for bs in batch_sizes] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["cuda", "triton"], + line_names=["CUDA Kernel", "Triton Kernel"], + styles=[("green", "-"), ("orange", "-")], + ylabel="us", + plot_name="ep-moe-post-reorder-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512 + + def alloc_tensors(): + down_output = torch.randn( + batch_size * topk, hidden_size, dtype=dtype, device=device + ) + output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device) + src2dst = torch.randint( + 0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device + ) + topk_ids = torch.randint( + start_expert_id, + end_expert_id + 1, + (batch_size, topk), + dtype=torch.int32, + device=device, + ) + topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device) + return down_output, output, src2dst, topk_ids, topk_weights + + quantiles = [0.5, 0.2, 0.8] + + if provider == "cuda": + d_out, out, s2d, tk_ids, tk_weights = alloc_tensors() + + def run_cuda(): + ep_moe_post_reorder( + d_out, + out, + s2d, + tk_ids, + tk_weights, + start_expert_id, + end_expert_id, + topk, + ) + + ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles) + + elif provider == "triton": + d_out, out, s2d, tk_ids, tk_weights = alloc_tensors() + + def run_triton(): + post_reorder_triton_kernel[(batch_size,)]( + d_out.view(-1), + out.view(-1), + s2d.view(-1), + tk_ids.view(-1), + tk_weights.view(-1), + start_expert_id, + end_expert_id, + topk, + hidden_size, + block_size, + ) + + ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles) + + else: + raise ValueError(f"Unknown provider: {provider}") + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 73e937ec1..29f9a7605 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -174,9 +174,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); m.def( - "ep_moe_pre_reorder(Tensor input_ptr, Tensor gateup_input_ptr, Tensor src2dst_ptr, Tensor topk_ids_ptr, Tensor " - "a1_scales_ptr, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()"); + "ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor " + "a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()"); m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder); + m.def( + "ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor " + "topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()"); + m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder); m.def( "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor " "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " diff --git a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu index a1e8856e8..f2811e98f 100644 --- a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu +++ b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu @@ -67,6 +67,57 @@ __global__ void ep_pre_reorder_cuda_kernel( } } +template +__global__ void ep_post_reorder_cuda_kernel( + const scalar_t* __restrict__ down_output_ptr, + scalar_t* __restrict__ output_ptr, + const int* __restrict__ src2dst_ptr, + const int* __restrict__ topk_ids_ptr, + const scalar_t* __restrict__ topk_weights_ptr, + int start_expert_id, + int end_expert_id, + int topk, + int hidden_size) { + const int token_idx = blockIdx.x; + const int tid = threadIdx.x; + + const int* token_src2dst = src2dst_ptr + token_idx * topk; + const int* token_topk_ids = topk_ids_ptr + token_idx * topk; + const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk; + + scalar_t* dst_ptr = output_ptr + static_cast(token_idx) * hidden_size; + + constexpr uint32_t vec_size = 16 / sizeof(scalar_t); + using vec_t = flashinfer::vec_t; + + const int vec_iters = hidden_size / vec_size; + for (int idx = tid; idx < vec_iters; idx += blockDim.x) { + float acc[vec_size] = {0}; + + for (int k = 0; k < topk; ++k) { + const int expert_id = token_topk_ids[k]; + if (expert_id < start_expert_id || expert_id > end_expert_id) continue; + const int src_row = token_src2dst[k]; + const scalar_t* src_ptr = down_output_ptr + static_cast(src_row) * hidden_size; + const float weight = static_cast(token_topk_weights[k]); + + vec_t src_vec; + src_vec.cast_load(src_ptr + idx * vec_size); + +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + acc[i] += static_cast(src_vec[i]) * weight; + } + } + vec_t out_vec; +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) + out_vec[i] = static_cast(acc[i]); + + out_vec.cast_store(dst_ptr + idx * vec_size); + } +} + void ep_moe_pre_reorder( torch::Tensor input, torch::Tensor gateup_input, @@ -77,8 +128,8 @@ void ep_moe_pre_reorder( int64_t end_expert_id, int64_t topk, bool use_per_token_if_dynamic) { - int total_blocks = input.size(0); - int block_size = 512; + const int total_blocks = input.size(0); + const int block_size = 512; dim3 grid(total_blocks); dim3 block(block_size); int hidden_size = input.size(1); @@ -98,3 +149,33 @@ void ep_moe_pre_reorder( return true; }); } + +void ep_moe_post_reorder( + torch::Tensor down_output, + torch::Tensor output, + torch::Tensor src2dst, + torch::Tensor topk_ids, + torch::Tensor topk_weights, + int64_t start_expert_id, + int64_t end_expert_id, + int64_t topk) { + const int total_tokens = output.size(0); + const int block_size = 512; + dim3 grid(total_tokens); + dim3 block(block_size); + const int hidden_size = output.size(1); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] { + ep_post_reorder_cuda_kernel<<>>( + static_cast(down_output.data_ptr()), + static_cast(output.data_ptr()), + src2dst.data_ptr(), + topk_ids.data_ptr(), + static_cast(topk_weights.data_ptr()), + static_cast(start_expert_id), + static_cast(end_expert_id), + static_cast(topk), + hidden_size); + return true; + }); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 93db3d952..586f7cafe 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -264,6 +264,16 @@ void ep_moe_pre_reorder( int64_t topk, bool use_per_token_if_dynamic); +void ep_moe_post_reorder( + torch::Tensor down_output, + torch::Tensor output, + torch::Tensor src2dst, + torch::Tensor topk_ids, + torch::Tensor topk_weights, + int64_t start_expert_id, + int64_t end_expert_id, + int64_t topk); + void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor); void cutlass_fp4_group_mm( diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index a7e371456..9aef5a2b0 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -49,6 +49,7 @@ from sgl_kernel.gemm import ( from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda from sgl_kernel.moe import ( cutlass_fp4_group_mm, + ep_moe_post_reorder, ep_moe_pre_reorder, fp8_blockwise_scaled_grouped_mm, moe_align_block_size, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index c1530c322..0b60f2e18 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -88,6 +88,28 @@ def ep_moe_pre_reorder( ) +def ep_moe_post_reorder( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + topk, +): + return torch.ops.sgl_kernel.ep_moe_post_reorder.default( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + topk, + ) + + def fp8_blockwise_scaled_grouped_mm( output, a_ptrs, diff --git a/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py b/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py new file mode 100644 index 000000000..d20e9c9a6 --- /dev/null +++ b/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py @@ -0,0 +1,163 @@ +import itertools + +import pytest +import torch +from sgl_kernel import ep_moe_post_reorder + +from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel + + +def create_test_tensors( + batch_size: int, + hidden_size: int, + topk: int, + start_expert_id: int, + end_expert_id: int, + dtype: torch.dtype, + device: torch.device, +): + down_output = torch.randn( + batch_size * topk, hidden_size, dtype=dtype, device=device + ) + + # Ensure src2dst has no duplicate destinations to avoid race conditions + total_tokens = batch_size * topk + dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32) + src2dst = dst_indices.view(batch_size, topk) + + topk_ids = torch.randint( + start_expert_id, + end_expert_id + 1, + (batch_size, topk), + dtype=torch.int32, + device=device, + ) + + topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device) + + return down_output, src2dst, topk_ids, topk_weights + + +def run_cuda_kernel( + down_output: torch.Tensor, + output: torch.Tensor, + src2dst: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + start_expert_id: int, + end_expert_id: int, + topk: int, +): + ep_moe_post_reorder( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + topk, + ) + return output + + +def run_triton_kernel( + down_output: torch.Tensor, + output: torch.Tensor, + src2dst: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + start_expert_id: int, + end_expert_id: int, + topk: int, + hidden_size: int, +): + batch_size = down_output.size(0) + block_size = 512 + + post_reorder_triton_kernel[(batch_size,)]( + down_output, + output, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + topk, + hidden_size, + block_size, + ) + return output + + +def assert_close(a, b): + a32, b32 = a.float(), b.float() + if a.dtype is torch.float16: + torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-2) + elif a.dtype is torch.bfloat16: + torch.testing.assert_close(a32, b32, rtol=1e-4, atol=1e-1) + else: + torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "batch_size,hidden_size,topk", + list(itertools.product([32, 64], [128, 256, 512], [2, 4, 8])), +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +def test_ep_moe_post_reorder_vs_triton( + batch_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, +): + device = torch.device("cuda") + start_expert_id = 0 + end_expert_id = 15 + + ( + down_output, + src2dst, + topk_ids, + topk_weights, + ) = create_test_tensors( + batch_size, + hidden_size, + topk, + start_expert_id, + end_expert_id, + dtype, + device, + ) + + output_cuda = torch.empty(batch_size, hidden_size, dtype=dtype, device=device) + output_triton = torch.empty(batch_size, hidden_size, dtype=dtype, device=device) + + cuda_output = run_cuda_kernel( + down_output, + output_cuda, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + topk, + ) + + triton_output = run_triton_kernel( + down_output, + output_triton, + src2dst, + topk_ids, + topk_weights, + start_expert_id, + end_expert_id, + topk, + hidden_size, + ) + + assert_close(cuda_output, triton_output) + + +if __name__ == "__main__": + pytest.main([__file__])