diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 489e4563a..75c929662 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -224,6 +224,7 @@ set(SOURCES "csrc/moe/moe_topk_softmax_kernels.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu" "csrc/moe/prepare_moe_input.cu" + "csrc/moe/ep_moe_reorder_kernel.cu" "csrc/speculative/eagle_utils.cu" "csrc/speculative/speculative_sampling.cu" "csrc/speculative/packbit.cu" diff --git a/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py new file mode 100644 index 000000000..ac0eadc32 --- /dev/null +++ b/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py @@ -0,0 +1,100 @@ +import itertools + +import torch +import triton +import triton.language as tl +from sgl_kernel import ep_moe_pre_reorder + +from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel + +batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096] +configs = [(bs,) for bs in batch_sizes] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["cuda", "triton"], + line_names=["CUDA Kernel", "Triton Kernel"], + styles=[("green", "-"), ("orange", "-")], + ylabel="us", + plot_name="ep-moe-pre-reorder-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = torch.float32 + device = torch.device("cuda") + hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512 + + # Allocate fresh tensors for every run to match bench_moe_fused_gate style + def alloc_tensors(): + input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + gateup_input = torch.zeros( + batch_size * topk, hidden_size, dtype=dtype, device=device + ) + src2dst = torch.randint( + 0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device + ) + topk_ids = torch.randint( + start_expert_id, + end_expert_id + 1, + (batch_size, topk), + dtype=torch.int32, + device=device, + ) + a1_scales = torch.rand( + end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device + ) + return input_, gateup_input, src2dst, topk_ids, a1_scales + + quantiles = [0.5, 0.2, 0.8] + + if provider == "cuda": + + def run_cuda(): + inp, gout, s2d, tk_ids, scales = alloc_tensors() + ep_moe_pre_reorder( + inp, + gout, + s2d, + tk_ids, + scales, + start_expert_id, + end_expert_id, + topk, + True, + ) + + ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles) + + elif provider == "triton": + + def run_triton(): + inp, gout, s2d, tk_ids, scales = alloc_tensors() + pre_reorder_triton_kernel[(batch_size,)]( + inp.view(-1), + gout.view(-1), + s2d.view(-1), + tk_ids.view(-1), + scales, + start_expert_id, + end_expert_id, + topk, + hidden_size, + block_size, + True, + ) + + ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles) + + else: + raise ValueError(f"Unknown provider: {provider}") + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index d83944b56..2a28eb103 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -150,6 +150,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "n_share_experts_fusion, float routed_scaling_factor) -> " "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); + m.def( + "ep_moe_pre_reorder(Tensor input_ptr, Tensor gateup_input_ptr, Tensor src2dst_ptr, Tensor topk_ids_ptr, Tensor " + "a1_scales_ptr, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()"); + m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder); m.def( "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor " "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " diff --git a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu new file mode 100644 index 000000000..a79f123a4 --- /dev/null +++ b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu @@ -0,0 +1,89 @@ +#include +#include +#include + +#include +#include + +#include "utils.h" + +__global__ void ep_pre_reorder_cuda_kernel( + const float* __restrict__ input_ptr, + float* __restrict__ gateup_input_ptr, + const int* __restrict__ src2dst_ptr, + const int* __restrict__ topk_ids_ptr, + const float* __restrict__ a1_scales_ptr, + int start_expert_id, + int end_expert_id, + int topk, + int hidden_size, + bool use_per_token_if_dynamic) { + int token_idx = blockIdx.x; + int tid = threadIdx.x; + + const float* src_ptr = input_ptr + int64_t(token_idx) * hidden_size; + const int* token_src2dst = src2dst_ptr + token_idx * topk; + const int* token_topk_ids = topk_ids_ptr + token_idx * topk; + + for (int k = 0; k < topk; ++k) { + int expert_id = token_topk_ids[k]; + if (expert_id < start_expert_id || expert_id > end_expert_id) continue; + + float scale = 1.0f; + + if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) { + scale = 1.0f / a1_scales_ptr[token_idx]; + } + + if (a1_scales_ptr != nullptr) { + if (!use_per_token_if_dynamic) { + scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id]; + } + } + + int dst_idx = token_src2dst[k]; + float* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size; + + constexpr uint32_t vec_size = 16 / sizeof(float); + using vec_t = flashinfer::vec_t; + + for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) { + vec_t input_vec, output_vec; + input_vec.cast_load(src_ptr + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + float val = static_cast(input_vec[i]); + output_vec[i] = val * scale; + } + output_vec.cast_store(dst_ptr + idx * vec_size); + } + } +} + +void ep_moe_pre_reorder( + torch::Tensor input, + torch::Tensor gateup_input, + torch::Tensor src2dst, + torch::Tensor topk_ids, + torch::Tensor a1_scales, + int64_t start_expert_id, + int64_t end_expert_id, + int64_t topk, + bool use_per_token_if_dynamic) { + int total_blocks = input.size(0); + int block_size = 512; + dim3 grid(total_blocks); + dim3 block(block_size); + int hidden_size = input.size(1); + ep_pre_reorder_cuda_kernel<<>>( + input.data_ptr(), + gateup_input.data_ptr(), + src2dst.data_ptr(), + topk_ids.data_ptr(), + a1_scales.defined() ? a1_scales.data_ptr() : nullptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + use_per_token_if_dynamic); +} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index b5e376dc8..9c3432f46 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -240,6 +240,17 @@ void prepare_moe_input( const int64_t n, const int64_t k); +void ep_moe_pre_reorder( + torch::Tensor input, + torch::Tensor gateup_input, + torch::Tensor src2dst, + torch::Tensor topk_ids, + torch::Tensor a1_scales, + int64_t start_expert_id, + int64_t end_expert_id, + int64_t topk, + bool use_per_token_if_dynamic); + /* * From csrc/speculative */ diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index ec97fa4b5..002d8d394 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -46,6 +46,7 @@ from sgl_kernel.gemm import ( ) from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda from sgl_kernel.moe import ( + ep_moe_pre_reorder, fp8_blockwise_scaled_grouped_mm, moe_align_block_size, moe_fused_gate, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index e7b5eede0..27808494d 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -62,6 +62,30 @@ def moe_fused_gate( ) +def ep_moe_pre_reorder( + input_tensor, + gateup_input, + src2dst, + topk_ids, + a1_scales, + start_expert_id, + end_expert_id, + topk, + use_per_token_if_dynamic, +): + return torch.ops.sgl_kernel.ep_moe_pre_reorder.default( + input_tensor, + gateup_input, + src2dst, + topk_ids, + a1_scales, + start_expert_id, + end_expert_id, + topk, + use_per_token_if_dynamic, + ) + + def fp8_blockwise_scaled_grouped_mm( output, a_ptrs,