import itertools

import torch
import triton
from sgl_kernel import ep_moe_silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel

batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
hidden_size_range = [1024, 2048, 4096, 8192]
block_size_range = [128, 256, 512]
configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "hidden_size", "block_size"],
        x_vals=[list(cfg) for cfg in configs],
        line_arg="provider",
        line_vals=["cuda", "triton"],
        line_names=["CUDA Kernel", "Triton Kernel"],
        styles=[("green", "-"), ("orange", "-")],
        ylabel="us",
        plot_name="ep-moe-silu-and-mul-performance",
        args={},
    )
)
def benchmark(batch_size, hidden_size, block_size, provider):
    """Time one (batch_size, hidden_size, block_size) point for one provider.

    Args:
        batch_size: Number of token rows in ``gateup_output``.
        hidden_size: Width of ``gateup_output``; the output has half this width.
        block_size: Block size forwarded to the Triton kernel launch. The CUDA
            op takes no such argument, so it is ignored for that provider.
        provider: ``"cuda"`` for the sgl_kernel op, ``"triton"`` for the
            Triton reference kernel.

    Returns:
        ``(median, p20, p80)`` in microseconds, in the ``(mean, min, max)``
        order that ``triton.testing.perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is not one of the two known names.
    """
    if provider not in ("cuda", "triton"):
        raise ValueError(f"Unknown provider: {provider}")

    dtype = torch.bfloat16
    device = torch.device("cuda")

    half_hidden_size = hidden_size // 2
    start_expert_id, end_expert_id = 0, 255
    # NOTE: ``block_size`` is intentionally NOT overridden here; pinning it to a
    # constant would make the block_size sweep in ``configs`` meaningless.
    quantiles = [0.5, 0.2, 0.8]

    def alloc_tensors():
        gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
        down_input = torch.empty(
            batch_size, half_hidden_size, dtype=dtype, device=device
        )
        reorder_topk_ids = torch.randint(
            start_expert_id,
            end_expert_id + 1,
            (batch_size,),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device
        )
        return gateup_output, down_input, reorder_topk_ids, scales

    gateup, down, ids, scales = alloc_tensors()

    if provider == "cuda":

        def run():
            ep_moe_silu_and_mul(
                gateup,
                down,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
            )

    else:  # provider == "triton"

        def run():
            silu_and_mul_triton_kernel[(batch_size,)](
                gateup.view(-1),
                down.view(-1),
                hidden_size,
                ids,
                scales,
                start_expert_id,
                end_expert_id,
                block_size,
            )

    # With quantiles=[0.5, 0.2, 0.8] do_bench returns (median, p20, p80) in ms.
    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    # Times (unlike throughputs) keep their ordering when scaled, so the
    # low quantile stays in the "min" slot and the high one in the "max" slot.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms


if __name__ == "__main__":
    benchmark.run(print_data=True)
// NOTE(review): the angle-bracket include targets were lost when this patch was
// extracted. The list below is what this translation unit needs; confirm it
// against the repository copy before merging.
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

using namespace flashinfer;

// SiLU(x) = x / (1 + exp(-x)), evaluated in fp32 and then rounded to scalar_t.
template <typename scalar_t>
__device__ inline scalar_t silu_quantize(float x);

template <>
__device__ inline float silu_quantize<float>(float x) {
  float y = x / (1.f + __expf(-x));
  return y;
}

template <>
__device__ inline __half silu_quantize<__half>(float x) {
  float y = x / (1.f + __expf(-x));
  return __float2half_rn(y);
}

template <>
__device__ inline __nv_bfloat16 silu_quantize<__nv_bfloat16>(float x) {
  float y = x / (1.f + __expf(-x));
  return __float2bfloat16_rn(y);
}

// One thread block per token row (blockIdx.x == token index).
//
// For a row whose expert id lies in [start_expert_id, end_expert_id] this
// writes, for every column c < hidden_size / 2:
//   down_input[row, c] = silu(gate[row, c]) * up[row, c] * (1 / scale)
// where gate is the first half of the gateup_output row, up is the second
// half, and scale = scales[expert_id - start_expert_id] (1 if scales is null).
// Rows whose expert id is outside the range are left untouched.
template <typename scalar_t>
__global__ void ep_moe_act_and_mul_cuda_kernel(
    const scalar_t* __restrict__ gateup_output,
    scalar_t* __restrict__ down_input,
    const int* __restrict__ reorder_topk_ids,
    const float* __restrict__ scales,
    int start_expert_id,
    int end_expert_id,
    int hidden_size) {
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int64_t token_idx = blockIdx.x;
  const int64_t thread_idx = threadIdx.x;
  const int64_t stride = blockDim.x;

  const int half_hidden_size = hidden_size >> 1;
  const int expert_id = reorder_topk_ids[token_idx];

  // The whole block exits together: expert_id is uniform across the block.
  if (expert_id < start_expert_id || expert_id > end_expert_id) return;

  const scalar_t* gate_output_ptr = gateup_output + token_idx * hidden_size;
  const scalar_t* up_output_ptr = gate_output_ptr + half_hidden_size;
  scalar_t* dst_ptr = down_input + token_idx * half_hidden_size;
  // The reciprocal scale is rounded to scalar_t before use, so the product
  // below is carried out entirely in scalar_t.
  const scalar_t scale_q = static_cast<scalar_t>(scales ? (1.f / scales[expert_id - start_expert_id]) : 1.f);

  // The 16-byte vector loads/stores need every row pointer (gate, up, dst) to
  // be 16-byte aligned. That holds only when the base pointers are aligned
  // and half_hidden_size is a multiple of vec_size; otherwise the up/dst rows
  // start at a misaligned address, so fall back to the scalar loop for the
  // whole row. The predicate is uniform across the block.
  const bool can_vectorize = (static_cast<uint32_t>(half_hidden_size) % vec_size == 0) &&
                             (reinterpret_cast<uintptr_t>(gateup_output) % 16 == 0) &&
                             (reinterpret_cast<uintptr_t>(down_input) % 16 == 0);
  const uint32_t vec_elements = can_vectorize ? static_cast<uint32_t>(half_hidden_size) / vec_size : 0;

#pragma unroll 1
  for (uint32_t idx = thread_idx; idx < vec_elements; idx += stride) {
    vec_t gate_vec, up_vec, out_vec;
    gate_vec.load(gate_output_ptr + idx * vec_size);
    up_vec.load(up_output_ptr + idx * vec_size);

#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      float gate_f = static_cast<float>(gate_vec[i]);
      scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
      scalar_t prod = gate_q * up_vec[i] * scale_q;
      out_vec[i] = prod;
    }
    out_vec.store(dst_ptr + idx * vec_size);
  }

  // Scalar loop: covers the entire row when the vector path is disabled.
  const int64_t scalar_start = static_cast<int64_t>(vec_elements) * vec_size + thread_idx;
#pragma unroll 1
  for (int64_t idx = scalar_start; idx < half_hidden_size; idx += stride) {
    float gate_f = static_cast<float>(gate_output_ptr[idx]);
    scalar_t gate_q = silu_quantize<scalar_t>(gate_f);
    dst_ptr[idx] = gate_q * up_output_ptr[idx] * scale_q;
  }
}

// Host entry point registered as sgl_kernel::ep_moe_silu_and_mul.
//
// gateup_output:    [total_tokens, hidden_size], float / half / bfloat16.
// down_input:       [total_tokens, hidden_size / 2], same dtype; written in place.
// reorder_topk_ids: int32 expert id per token row.
// scales:           float32, indexed by (expert_id - start_expert_id).
// start/end_expert_id: inclusive range of expert ids that are processed.
void ep_moe_silu_and_mul(
    torch::Tensor gateup_output,
    torch::Tensor down_input,
    torch::Tensor reorder_topk_ids,
    torch::Tensor scales,
    int64_t start_expert_id,
    int64_t end_expert_id) {
  TORCH_CHECK(gateup_output.dim() == 2, "gateup_output must be 2-D [total_tokens, hidden_size]");
  TORCH_CHECK(gateup_output.is_contiguous(), "gateup_output must be contiguous");
  TORCH_CHECK(down_input.is_contiguous(), "down_input must be contiguous");
  TORCH_CHECK(
      down_input.scalar_type() == gateup_output.scalar_type(), "down_input and gateup_output must share a dtype");

  const int total_tokens = gateup_output.size(0);
  const int hidden_size = gateup_output.size(1);

  // A zero-sized grid is an invalid launch configuration; nothing to do anyway.
  if (total_tokens == 0) return;

  // Launch on the caller's device and current stream rather than the legacy
  // default stream, so the op composes with torch.cuda.stream contexts.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(gateup_output));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(gateup_output.scalar_type(), scalar_t, [&] {
    dim3 grid(total_tokens);
    constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
    const int half_hidden_size = hidden_size >> 1;
    // One thread per 16-byte vector, at least 256, rounded up to a full warp
    // and capped at 1024 threads per block.
    uint32_t threads = (half_hidden_size + vec_size - 1) / vec_size;
    threads = std::max<uint32_t>(threads, 256);
    threads = ((threads + 31) & ~31U);
    dim3 block(std::min(threads, 1024U));
    ep_moe_act_and_mul_cuda_kernel<scalar_t><<<grid, block, 0, stream>>>(
        static_cast<scalar_t*>(gateup_output.data_ptr()),
        static_cast<scalar_t*>(down_input.data_ptr()),
        reorder_topk_ids.data_ptr<int>(),
        scales.defined() ? scales.data_ptr<float>() : nullptr,
        static_cast<int>(start_expert_id),
        static_cast<int>(end_expert_id),
        hidden_size);
    return true;
  });
}
def ep_moe_silu_and_mul(
    gateup_output,
    down_input,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
):
    """Forward the arguments, unchanged and in order, to the registered op.

    The op ``sgl_kernel::ep_moe_silu_and_mul`` is declared with four tensor
    arguments followed by two integers and an empty return (``-> ()``); its
    result is delivered by writing into ``down_input``.
    """
    kernel_op = torch.ops.sgl_kernel.ep_moe_silu_and_mul.default
    call_args = (
        gateup_output,
        down_input,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )
    return kernel_op(*call_args)
import pytest
import torch
from sgl_kernel import ep_moe_silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel


def create_test_tensors(
    total_tokens: int,
    hidden_size: int,
    start_expert_id: int,
    end_expert_id: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Build random inputs and an uninitialized output buffer for one case.

    Returns:
        ``(gateup_output, down_input, reorder_topk_ids, scales)`` where
        ``gateup_output`` is ``(total_tokens, hidden_size)``, ``down_input`` is
        ``(total_tokens, hidden_size // 2)``, ``reorder_topk_ids`` holds int32
        ids drawn from ``[start_expert_id, end_expert_id]`` and ``scales`` has
        one float32 entry per expert in that range, drawn from ``[0.5, 1.3)``.
    """
    gateup_output = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device)

    reorder_topk_ids = torch.randint(
        start_expert_id,
        end_expert_id + 1,
        (total_tokens,),
        dtype=torch.int32,
        device=device,
    )

    num_experts = end_expert_id - start_expert_id + 1
    # Keep scales away from zero: both kernels multiply by 1 / scale.
    scales = torch.rand(num_experts, dtype=torch.float32, device=device) * 0.8 + 0.5

    half_hidden = hidden_size // 2
    down_input = torch.empty(total_tokens, half_hidden, dtype=dtype, device=device)

    return gateup_output, down_input, reorder_topk_ids, scales


def run_cuda_kernel(
    gateup_output: torch.Tensor,
    down_input: torch.Tensor,
    reorder_topk_ids: torch.Tensor,
    scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
):
    """Run the sgl_kernel op, which fills ``down_input``, and return that buffer."""
    ep_moe_silu_and_mul(
        gateup_output,
        down_input,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )
    return down_input


def run_triton_kernel(
    gateup_output: torch.Tensor,
    down_input: torch.Tensor,
    reorder_topk_ids: torch.Tensor,
    scales: torch.Tensor,
    start_expert_id: int,
    end_expert_id: int,
    hidden_size: int,
):
    """Run the Triton reference kernel (one program per token) and return ``down_input``."""
    total_tokens = gateup_output.size(0)
    block_size = 512

    silu_and_mul_triton_kernel[(total_tokens,)](
        gateup_output,
        down_input,
        hidden_size,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
        block_size,
    )
    return down_input


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
# Stacked parametrize decorators yield the full cartesian product of values.
@pytest.mark.parametrize("total_tokens", [32, 256, 1024])
@pytest.mark.parametrize("hidden_size", [128, 256, 512])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
def test_ep_moe_silu_and_mul_vs_triton(
    total_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
):
    """The CUDA op must match the Triton reference on identical inputs."""
    device = torch.device("cuda")
    start_expert_id = 0
    end_expert_id = 15

    (
        gateup_output,
        down_input_cuda,
        reorder_topk_ids,
        scales,
    ) = create_test_tensors(
        total_tokens,
        hidden_size,
        start_expert_id,
        end_expert_id,
        dtype,
        device,
    )

    # Reuse the helper's buffer for the CUDA run; the Triton run gets its own
    # buffer so the two kernels never write to the same memory.
    down_input_triton = torch.empty_like(down_input_cuda)

    cuda_output = run_cuda_kernel(
        gateup_output,
        down_input_cuda,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
    )

    triton_output = run_triton_kernel(
        gateup_output,
        down_input_triton,
        reorder_topk_ids,
        scales,
        start_expert_id,
        end_expert_id,
        hidden_size,
    )

    torch.testing.assert_close(
        cuda_output,
        triton_output,
        rtol=1e-5,
        atol=1e-5,
    )


if __name__ == "__main__":
    pytest.main([__file__])