diff --git a/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py index ac0eadc32..7623d3109 100644 --- a/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py +++ b/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py @@ -1,8 +1,5 @@ -import itertools - import torch import triton -import triton.language as tl from sgl_kernel import ep_moe_pre_reorder from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel @@ -25,9 +22,15 @@ configs = [(bs,) for bs in batch_sizes] ) ) def benchmark(batch_size, provider): - dtype = torch.float32 + dtype = torch.bfloat16 device = torch.device("cuda") - hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512 + hidden_size, topk, start_expert_id, end_expert_id, block_size = ( + 4096, + 8, + 0, + 255, + 512, + ) # Allocate fresh tensors for every run to match bench_moe_fused_gate style def alloc_tensors(): @@ -53,9 +56,9 @@ def benchmark(batch_size, provider): quantiles = [0.5, 0.2, 0.8] if provider == "cuda": + inp, gout, s2d, tk_ids, scales = alloc_tensors() def run_cuda(): - inp, gout, s2d, tk_ids, scales = alloc_tensors() ep_moe_pre_reorder( inp, gout, @@ -71,9 +74,9 @@ def benchmark(batch_size, provider): ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles) elif provider == "triton": + inp, gout, s2d, tk_ids, scales = alloc_tensors() def run_triton(): - inp, gout, s2d, tk_ids, scales = alloc_tensors() pre_reorder_triton_kernel[(batch_size,)]( inp.view(-1), gout.view(-1), diff --git a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu index a79f123a4..a1e8856e8 100644 --- a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu +++ b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu @@ -7,9 +7,10 @@ #include "utils.h" +template __global__ void ep_pre_reorder_cuda_kernel( - const float* __restrict__ input_ptr, - float* __restrict__ gateup_input_ptr, + const scalar_t* __restrict__ input_ptr, + scalar_t* __restrict__ gateup_input_ptr, const int* __restrict__ src2dst_ptr, const int* __restrict__ topk_ids_ptr, const float* __restrict__ a1_scales_ptr, @@ -21,20 +22,20 @@ __global__ void ep_pre_reorder_cuda_kernel( int token_idx = blockIdx.x; int tid = threadIdx.x; - const float* src_ptr = input_ptr + int64_t(token_idx) * hidden_size; + const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size; const int* token_src2dst = src2dst_ptr + token_idx * topk; const int* token_topk_ids = topk_ids_ptr + token_idx * topk; + float scale = 1.0f; + + if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) { + scale = 1.0f / a1_scales_ptr[token_idx]; + } + for (int k = 0; k < topk; ++k) { int expert_id = token_topk_ids[k]; if (expert_id < start_expert_id || expert_id > end_expert_id) continue; - float scale = 1.0f; - - if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) { - scale = 1.0f / a1_scales_ptr[token_idx]; - } - if (a1_scales_ptr != nullptr) { if (!use_per_token_if_dynamic) { scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id]; @@ -42,21 +43,27 @@ __global__ void ep_pre_reorder_cuda_kernel( } int dst_idx = token_src2dst[k]; - float* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size; + scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size; - constexpr uint32_t vec_size = 16 / sizeof(float); - using vec_t = flashinfer::vec_t; + constexpr uint32_t vec_size = 16 / sizeof(scalar_t); + using vec_t = flashinfer::vec_t; + int vec_elements = (hidden_size / vec_size) * vec_size; for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) { vec_t input_vec, output_vec; input_vec.cast_load(src_ptr + idx * vec_size); #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { float val = static_cast(input_vec[i]); - output_vec[i] = val * scale; + output_vec[i] = static_cast(val * scale); } output_vec.cast_store(dst_ptr + idx * vec_size); } + + for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) { + float val = static_cast(src_ptr[idx]); + dst_ptr[idx] = static_cast(val * scale); + } } } @@ -75,15 +82,19 @@ void ep_moe_pre_reorder( dim3 grid(total_blocks); dim3 block(block_size); int hidden_size = input.size(1); - ep_pre_reorder_cuda_kernel<<>>( - input.data_ptr(), - gateup_input.data_ptr(), - src2dst.data_ptr(), - topk_ids.data_ptr(), - a1_scales.defined() ? a1_scales.data_ptr() : nullptr, - start_expert_id, - end_expert_id, - topk, - hidden_size, - use_per_token_if_dynamic); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { + ep_pre_reorder_cuda_kernel<<>>( + static_cast(input.data_ptr()), + static_cast(gateup_input.data_ptr()), + src2dst.data_ptr(), + topk_ids.data_ptr(), + a1_scales.defined() ? a1_scales.data_ptr() : nullptr, + start_expert_id, + end_expert_id, + topk, + hidden_size, + use_per_token_if_dynamic); + return true; + }); } diff --git a/sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py b/sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py new file mode 100644 index 000000000..718f633c9 --- /dev/null +++ b/sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py @@ -0,0 +1,181 @@ +import itertools + +import pytest +import torch +from sgl_kernel import ep_moe_pre_reorder + +from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel + + +def create_test_tensors( + batch_size: int, + hidden_size: int, + topk: int, + start_expert_id: int, + end_expert_id: int, + dtype: torch.dtype, + device: torch.device, + use_per_token_if_dynamic: bool = True, +): + input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + + # Ensure src2dst has no duplicate destinations to avoid race conditions + total_tokens = batch_size * topk + dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32) + src2dst = dst_indices.view(batch_size, topk) + + topk_ids = torch.randint( + start_expert_id, + end_expert_id + 1, + (batch_size, topk), + dtype=torch.int32, + device=device, + ) + + if use_per_token_if_dynamic: + a1_scales = ( + torch.rand(batch_size, dtype=torch.float32, device=device) * 0.8 + 0.2 + ) + else: + a1_scales = ( + torch.rand( + end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device + ) + * 0.8 + + 0.2 + ) + + return input_tensor, src2dst, topk_ids, a1_scales + + +def run_cuda_kernel( + input_tensor: torch.Tensor, + gateup_input: torch.Tensor, + src2dst: torch.Tensor, + topk_ids: torch.Tensor, + a1_scales: torch.Tensor, + start_expert_id: int, + end_expert_id: int, + topk: int, + use_per_token_if_dynamic: bool, +): + ep_moe_pre_reorder( + input_tensor, + gateup_input, + src2dst, + topk_ids, + a1_scales, + start_expert_id, + end_expert_id, + topk, + use_per_token_if_dynamic, + ) + return gateup_input + + +def run_triton_kernel( + input_tensor: torch.Tensor, + gateup_input: torch.Tensor, + src2dst: torch.Tensor, + topk_ids: torch.Tensor, + a1_scales: torch.Tensor, + start_expert_id: int, + end_expert_id: int, + topk: int, + hidden_size: int, + use_per_token_if_dynamic: bool, +): + batch_size = input_tensor.size(0) + block_size = 512 + + pre_reorder_triton_kernel[(batch_size,)]( + input_tensor, + gateup_input, + src2dst, + topk_ids, + a1_scales, + start_expert_id, + end_expert_id, + topk, + hidden_size, + block_size, + use_per_token_if_dynamic, + ) + return gateup_input + + +@pytest.mark.parametrize( + "batch_size,hidden_size,topk", + list(itertools.product([32, 64, 128], [512, 1024, 2048], [4, 8])), +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("use_per_token_if_dynamic", [True, False]) +def test_ep_moe_pre_reorder_vs_triton( + batch_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_per_token_if_dynamic: bool, +): + device = torch.device("cuda") + start_expert_id = 0 + end_expert_id = 15 + + ( + input_tensor, + src2dst, + topk_ids, + a1_scales, + ) = create_test_tensors( + batch_size, + hidden_size, + topk, + start_expert_id, + end_expert_id, + dtype, + device, + use_per_token_if_dynamic, + ) + + gateup_input_cuda = torch.empty( + batch_size * topk, hidden_size, dtype=dtype, device=device + ) + gateup_input_triton = torch.empty( + batch_size * topk, hidden_size, dtype=dtype, device=device + ) + + cuda_output = run_cuda_kernel( + input_tensor, + gateup_input_cuda, + src2dst, + topk_ids, + a1_scales, + start_expert_id, + end_expert_id, + topk, + use_per_token_if_dynamic, + ) + + triton_output = run_triton_kernel( + input_tensor, + gateup_input_triton, + src2dst, + topk_ids, + a1_scales, + start_expert_id, + end_expert_id, + topk, + hidden_size, + use_per_token_if_dynamic, + ) + + torch.testing.assert_close( + cuda_output.float(), + triton_output.float(), + rtol=1e-5, + atol=1e-5, + ) + + +if __name__ == "__main__": + pytest.main([__file__])