diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 87389ec4b..13ef9ce49 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -282,8 +282,6 @@ set(SOURCES "csrc/moe/nvfp4_blockwise_moe.cu" "csrc/moe/fp8_blockwise_moe_kernel.cu" "csrc/moe/prepare_moe_input.cu" - "csrc/moe/ep_moe_reorder_kernel.cu" - "csrc/moe/ep_moe_silu_and_mul_kernel.cu" "csrc/memory/store.cu" "csrc/kvcacheio/transfer.cu" diff --git a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py index 078e2c131..faadd7698 100644 --- a/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py +++ b/sgl-kernel/benchmark/bench_moe_ep_post_reorder.py @@ -1,6 +1,5 @@ import torch import triton -from sgl_kernel import ep_moe_post_reorder from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel @@ -13,9 +12,9 @@ configs = [(bs,) for bs in batch_sizes] x_names=["batch_size"], x_vals=[list(_) for _ in configs], line_arg="provider", - line_vals=["cuda", "triton"], - line_names=["CUDA Kernel", "Triton Kernel"], - styles=[("green", "-"), ("orange", "-")], + line_vals=["triton"], + line_names=["Triton Kernel"], + styles=[("orange", "-")], ylabel="us", plot_name="ep-moe-post-reorder-performance", args={}, @@ -46,24 +45,7 @@ def benchmark(batch_size, provider): quantiles = [0.5, 0.2, 0.8] - if provider == "cuda": - d_out, out, s2d, tk_ids, tk_weights = alloc_tensors() - - def run_cuda(): - ep_moe_post_reorder( - d_out, - out, - s2d, - tk_ids, - tk_weights, - start_expert_id, - end_expert_id, - topk, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles) - - elif provider == "triton": + if provider == "triton": d_out, out, s2d, tk_ids, tk_weights = alloc_tensors() def run_triton(): diff --git a/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py b/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py deleted file mode 100644 index 7623d3109..000000000 --- a/sgl-kernel/benchmark/bench_moe_ep_pre_reorder.py +++ /dev/null @@ -1,103 +0,0 @@ -import torch -import triton -from sgl_kernel import ep_moe_pre_reorder - -from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel - -batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096] -configs = [(bs,) for bs in batch_sizes] - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["cuda", "triton"], - line_names=["CUDA Kernel", "Triton Kernel"], - styles=[("green", "-"), ("orange", "-")], - ylabel="us", - plot_name="ep-moe-pre-reorder-performance", - args={}, - ) -) -def benchmark(batch_size, provider): - dtype = torch.bfloat16 - device = torch.device("cuda") - hidden_size, topk, start_expert_id, end_expert_id, block_size = ( - 4096, - 8, - 0, - 255, - 512, - ) - - # Allocate fresh tensors for every run to match bench_moe_fused_gate style - def alloc_tensors(): - input_ = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) - gateup_input = torch.zeros( - batch_size * topk, hidden_size, dtype=dtype, device=device - ) - src2dst = torch.randint( - 0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device - ) - topk_ids = torch.randint( - start_expert_id, - end_expert_id + 1, - (batch_size, topk), - dtype=torch.int32, - device=device, - ) - a1_scales = torch.rand( - end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device - ) - return input_, gateup_input, src2dst, topk_ids, a1_scales - - quantiles = [0.5, 0.2, 0.8] - - if provider == "cuda": - inp, gout, s2d, tk_ids, scales = alloc_tensors() - - def run_cuda(): - ep_moe_pre_reorder( - inp, - gout, - s2d, - tk_ids, - scales, - start_expert_id, - end_expert_id, - topk, - True, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles) - - elif provider == "triton": - inp, gout, s2d, tk_ids, scales = alloc_tensors() - - def run_triton(): - pre_reorder_triton_kernel[(batch_size,)]( - inp.view(-1), - gout.view(-1), - s2d.view(-1), - tk_ids.view(-1), - scales, - start_expert_id, - end_expert_id, - topk, - hidden_size, - block_size, - True, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles) - - else: - raise ValueError(f"Unknown provider: {provider}") - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - -if __name__ == "__main__": - benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_moe_silu_and_mul.py b/sgl-kernel/benchmark/bench_moe_silu_and_mul.py deleted file mode 100644 index 68f54bd32..000000000 --- a/sgl-kernel/benchmark/bench_moe_silu_and_mul.py +++ /dev/null @@ -1,92 +0,0 @@ -import itertools - -import torch -import triton -from sgl_kernel import ep_moe_silu_and_mul - -from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel - -batch_size_range = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096] -hidden_size_range = [1024, 2048, 4096, 8192] -block_size_range = [128, 256, 512] -configs = list(itertools.product(batch_size_range, hidden_size_range, block_size_range)) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "hidden_size", "block_size"], - x_vals=[list(cfg) for cfg in configs], - line_arg="provider", - line_vals=["cuda", "triton"], - line_names=["CUDA Kernel", "Triton Kernel"], - styles=[("green", "-"), ("orange", "-")], - ylabel="us", - plot_name="ep-moe-silu-and-mul-performance", - args={}, - ) -) -def benchmark(batch_size, hidden_size, block_size, provider): - dtype = torch.bfloat16 - device = torch.device("cuda") - - half_hidden_size = hidden_size // 2 - start_expert_id, end_expert_id = 0, 255 - block_size = 512 - quantiles = [0.5, 0.2, 0.8] - - def alloc_tensors(): - gateup_output = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) - down_input = torch.empty( - batch_size, half_hidden_size, dtype=dtype, device=device - ) - reorder_topk_ids = torch.randint( - start_expert_id, - end_expert_id + 1, - (batch_size,), - dtype=torch.int32, - device=device, - ) - scales = torch.rand( - end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device - ) - return gateup_output, down_input, reorder_topk_ids, scales - - if provider == "cuda": - gateup, down, ids, scales = alloc_tensors() - - def run_cuda(): - ep_moe_silu_and_mul( - gateup, - down, - ids, - scales, - start_expert_id, - end_expert_id, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles) - - elif provider == "triton": - gateup, down, ids, scales = alloc_tensors() - - def run_triton(): - silu_and_mul_triton_kernel[(batch_size,)]( - gateup.view(-1), - down.view(-1), - hidden_size, - ids, - scales, - start_expert_id, - end_expert_id, - block_size, - ) - - ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles) - else: - raise ValueError(f"Unknown provider: {provider}") - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - -if __name__ == "__main__": - benchmark.run(print_data=True) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 18a141af1..5a87dd483 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -209,18 +209,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "num_fused_shared_experts, float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> " "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); - m.def( - "ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor " - "a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()"); - m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder); - m.def( - "ep_moe_silu_and_mul(Tensor gateup_output, Tensor down_input, Tensor reorder_topk_ids, Tensor scales, int " - "start_expert_id, int end_expert_id) -> ()"); - m.impl("ep_moe_silu_and_mul", torch::kCUDA, &ep_moe_silu_and_mul); - m.def( - "ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor " - "topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()"); - m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder); m.def( "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor " "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor " diff --git a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu b/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu deleted file mode 100644 index f2811e98f..000000000 --- a/sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu +++ /dev/null @@ -1,181 +0,0 @@ -#include -#include -#include - -#include -#include - -#include "utils.h" - -template -__global__ void ep_pre_reorder_cuda_kernel( - const scalar_t* __restrict__ input_ptr, - scalar_t* __restrict__ gateup_input_ptr, - const int* __restrict__ src2dst_ptr, - const int* __restrict__ topk_ids_ptr, - const float* __restrict__ a1_scales_ptr, - int start_expert_id, - int end_expert_id, - int topk, - int hidden_size, - bool use_per_token_if_dynamic) { - int token_idx = blockIdx.x; - int tid = threadIdx.x; - - const scalar_t* src_ptr = input_ptr + int64_t(token_idx) * hidden_size; - const int* token_src2dst = src2dst_ptr + token_idx * topk; - const int* token_topk_ids = topk_ids_ptr + token_idx * topk; - - float scale = 1.0f; - - if (a1_scales_ptr != nullptr and use_per_token_if_dynamic) { - scale = 1.0f / a1_scales_ptr[token_idx]; - } - - for (int k = 0; k < topk; ++k) { - int expert_id = token_topk_ids[k]; - if (expert_id < start_expert_id || expert_id > end_expert_id) continue; - - if (a1_scales_ptr != nullptr) { - if (!use_per_token_if_dynamic) { - scale = 1.0f / a1_scales_ptr[expert_id - start_expert_id]; - } - } - - int dst_idx = token_src2dst[k]; - scalar_t* dst_ptr = gateup_input_ptr + int64_t(dst_idx) * hidden_size; - - constexpr uint32_t vec_size = 16 / sizeof(scalar_t); - using vec_t = flashinfer::vec_t; - - int vec_elements = (hidden_size / vec_size) * vec_size; - for (int idx = tid; idx < hidden_size / vec_size; idx += blockDim.x) { - vec_t input_vec, output_vec; - input_vec.cast_load(src_ptr + idx * vec_size); -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float val = static_cast(input_vec[i]); - output_vec[i] = static_cast(val * scale); - } - output_vec.cast_store(dst_ptr + idx * vec_size); - } - - for (int idx = vec_elements + tid; idx < hidden_size; idx += blockDim.x) { - float val = static_cast(src_ptr[idx]); - dst_ptr[idx] = static_cast(val * scale); - } - } -} - -template -__global__ void ep_post_reorder_cuda_kernel( - const scalar_t* __restrict__ down_output_ptr, - scalar_t* __restrict__ output_ptr, - const int* __restrict__ src2dst_ptr, - const int* __restrict__ topk_ids_ptr, - const scalar_t* __restrict__ topk_weights_ptr, - int start_expert_id, - int end_expert_id, - int topk, - int hidden_size) { - const int token_idx = blockIdx.x; - const int tid = threadIdx.x; - - const int* token_src2dst = src2dst_ptr + token_idx * topk; - const int* token_topk_ids = topk_ids_ptr + token_idx * topk; - const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk; - - scalar_t* dst_ptr = output_ptr + static_cast(token_idx) * hidden_size; - - constexpr uint32_t vec_size = 16 / sizeof(scalar_t); - using vec_t = flashinfer::vec_t; - - const int vec_iters = hidden_size / vec_size; - for (int idx = tid; idx < vec_iters; idx += blockDim.x) { - float acc[vec_size] = {0}; - - for (int k = 0; k < topk; ++k) { - const int expert_id = token_topk_ids[k]; - if (expert_id < start_expert_id || expert_id > end_expert_id) continue; - const int src_row = token_src2dst[k]; - const scalar_t* src_ptr = down_output_ptr + static_cast(src_row) * hidden_size; - const float weight = static_cast(token_topk_weights[k]); - - vec_t src_vec; - src_vec.cast_load(src_ptr + idx * vec_size); - -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - acc[i] += static_cast(src_vec[i]) * weight; - } - } - vec_t out_vec; -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) - out_vec[i] = static_cast(acc[i]); - - out_vec.cast_store(dst_ptr + idx * vec_size); - } -} - -void ep_moe_pre_reorder( - torch::Tensor input, - torch::Tensor gateup_input, - torch::Tensor src2dst, - torch::Tensor topk_ids, - torch::Tensor a1_scales, - int64_t start_expert_id, - int64_t end_expert_id, - int64_t topk, - bool use_per_token_if_dynamic) { - const int total_blocks = input.size(0); - const int block_size = 512; - dim3 grid(total_blocks); - dim3 block(block_size); - int hidden_size = input.size(1); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] { - ep_pre_reorder_cuda_kernel<<>>( - static_cast(input.data_ptr()), - static_cast(gateup_input.data_ptr()), - src2dst.data_ptr(), - topk_ids.data_ptr(), - a1_scales.defined() ? a1_scales.data_ptr() : nullptr, - start_expert_id, - end_expert_id, - topk, - hidden_size, - use_per_token_if_dynamic); - return true; - }); -} - -void ep_moe_post_reorder( - torch::Tensor down_output, - torch::Tensor output, - torch::Tensor src2dst, - torch::Tensor topk_ids, - torch::Tensor topk_weights, - int64_t start_expert_id, - int64_t end_expert_id, - int64_t topk) { - const int total_tokens = output.size(0); - const int block_size = 512; - dim3 grid(total_tokens); - dim3 block(block_size); - const int hidden_size = output.size(1); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] { - ep_post_reorder_cuda_kernel<<>>( - static_cast(down_output.data_ptr()), - static_cast(output.data_ptr()), - src2dst.data_ptr(), - topk_ids.data_ptr(), - static_cast(topk_weights.data_ptr()), - static_cast(start_expert_id), - static_cast(end_expert_id), - static_cast(topk), - hidden_size); - return true; - }); -} diff --git a/sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu b/sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu deleted file mode 100644 index 4bbea8ac8..000000000 --- a/sgl-kernel/csrc/moe/ep_moe_silu_and_mul_kernel.cu +++ /dev/null @@ -1,115 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "utils.h" - -using namespace flashinfer; - -template -__device__ inline scalar_t silu_quantize(float x); - -template <> -__device__ inline float silu_quantize(float x) { - float y = x / (1.f + __expf(-x)); - return y; -} - -template <> -__device__ inline __half silu_quantize<__half>(float x) { - float y = x / (1.f + __expf(-x)); - return __float2half_rn(y); -} - -template <> -__device__ inline __nv_bfloat16 silu_quantize<__nv_bfloat16>(float x) { - float y = x / (1.f + __expf(-x)); - return __float2bfloat16_rn(y); -} - -template -__global__ void ep_moe_act_and_mul_cuda_kernel( - const scalar_t* __restrict__ gateup_output, - scalar_t* __restrict__ down_input, - const int* __restrict__ reorder_topk_ids, - const float* __restrict__ scales, - int start_expert_id, - int end_expert_id, - int hidden_size) { - constexpr uint32_t vec_size = 16 / sizeof(scalar_t); - using vec_t = flashinfer::vec_t; - - const int64_t token_idx = blockIdx.x; - const int64_t thread_idx = threadIdx.x; - const int64_t stride = blockDim.x; - - const int half_hidden_size = hidden_size >> 1; - const int expert_id = reorder_topk_ids[token_idx]; - - if (expert_id < start_expert_id || expert_id > end_expert_id) return; - const scalar_t* gate_output_ptr = gateup_output + static_cast(token_idx) * hidden_size; - const scalar_t* up_output_ptr = gate_output_ptr + half_hidden_size; - scalar_t* dst_ptr = down_input + static_cast(token_idx) * half_hidden_size; - scalar_t scale_q = static_cast(scales ? (1.f / scales[expert_id - start_expert_id]) : 1.f); - - const uint32_t vec_elements = half_hidden_size / vec_size; -#pragma unroll 1 - for (uint32_t idx = thread_idx; idx < vec_elements; idx += stride) { - vec_t gate_vec, up_vec, out_vec; - gate_vec.load(gate_output_ptr + idx * vec_size); - up_vec.load(up_output_ptr + idx * vec_size); - -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float gate_f = static_cast(gate_vec[i]); - scalar_t gate_q = silu_quantize(gate_f); - scalar_t prod = gate_q * up_vec[i] * scale_q; - out_vec[i] = prod; - } - out_vec.store(dst_ptr + idx * vec_size); - } - - const int64_t scalar_start = static_cast(vec_elements) * vec_size + thread_idx; -#pragma unroll 1 - for (int64_t idx = scalar_start; idx < half_hidden_size; idx += stride) { - float gate_f = static_cast(gate_output_ptr[idx]); - scalar_t gate_q = silu_quantize(gate_f); - dst_ptr[idx] = gate_q * up_output_ptr[idx] * scale_q; - } -} - -void ep_moe_silu_and_mul( - torch::Tensor gateup_output, - torch::Tensor down_input, - torch::Tensor reorder_topk_ids, - torch::Tensor scales, - int64_t start_expert_id, - int64_t end_expert_id) { - const int total_tokens = gateup_output.size(0); - const int hidden_size = gateup_output.size(1); - - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(gateup_output.scalar_type(), scalar_t, [&] { - dim3 grid(total_tokens); - constexpr uint32_t vec_size = 16 / sizeof(scalar_t); - const int half_hidden_size = hidden_size >> 1; - uint32_t threads = (half_hidden_size + vec_size - 1) / vec_size; - threads = std::max(threads, 256); - threads = ((threads + 31) & ~31U); - dim3 block(std::min(threads, 1024U)); - ep_moe_act_and_mul_cuda_kernel<<>>( - static_cast(gateup_output.data_ptr()), - static_cast(down_input.data_ptr()), - reorder_topk_ids.data_ptr(), - scales.defined() ? scales.data_ptr() : nullptr, - static_cast(start_expert_id), - static_cast(end_expert_id), - hidden_size); - return true; - }); -} diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 0b4b979ab..76969a6ee 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -325,35 +325,6 @@ void prepare_moe_input( const int64_t n, const int64_t k); -void ep_moe_pre_reorder( - torch::Tensor input, - torch::Tensor gateup_input, - torch::Tensor src2dst, - torch::Tensor topk_ids, - torch::Tensor a1_scales, - int64_t start_expert_id, - int64_t end_expert_id, - int64_t topk, - bool use_per_token_if_dynamic); - -void ep_moe_silu_and_mul( - torch::Tensor gateup_output, - torch::Tensor down_input, - torch::Tensor reorder_topk_ids, - torch::Tensor scales, - int64_t start_expert_id, - int64_t end_expert_id); - -void ep_moe_post_reorder( - torch::Tensor down_output, - torch::Tensor output, - torch::Tensor src2dst, - torch::Tensor topk_ids, - torch::Tensor topk_weights, - int64_t start_expert_id, - int64_t end_expert_id, - int64_t topk); - void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor); void apply_shuffle_mul_sum( diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index 0476ad696..25e4eaf3b 100755 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -77,9 +77,6 @@ from sgl_kernel.memory import set_kv_buffer_kernel from sgl_kernel.moe import ( apply_shuffle_mul_sum, cutlass_fp4_group_mm, - ep_moe_post_reorder, - ep_moe_pre_reorder, - ep_moe_silu_and_mul, fp8_blockwise_scaled_grouped_mm, moe_align_block_size, moe_fused_gate, diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 9008e7a79..66fec9f2b 100755 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -71,70 +71,6 @@ def moe_fused_gate( ) -def ep_moe_pre_reorder( - input_tensor, - gateup_input, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - use_per_token_if_dynamic, -): - return torch.ops.sgl_kernel.ep_moe_pre_reorder.default( - input_tensor, - gateup_input, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - use_per_token_if_dynamic, - ) - - -def ep_moe_silu_and_mul( - gateup_output, - down_input, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, -): - return torch.ops.sgl_kernel.ep_moe_silu_and_mul.default( - gateup_output, - down_input, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - ) - - -def ep_moe_post_reorder( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, -): - return torch.ops.sgl_kernel.ep_moe_post_reorder.default( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, - ) - - def fp8_blockwise_scaled_grouped_mm( output, a_ptrs, diff --git a/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py b/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py deleted file mode 100644 index 189173559..000000000 --- a/sgl-kernel/tests/test_ep_moe_post_reorder_kernel.py +++ /dev/null @@ -1,164 +0,0 @@ -import itertools - -import pytest -import torch -from sgl_kernel import ep_moe_post_reorder - -from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel - - -def create_test_tensors( - batch_size: int, - hidden_size: int, - topk: int, - start_expert_id: int, - end_expert_id: int, - dtype: torch.dtype, - device: torch.device, -): - down_output = torch.randn( - batch_size * topk, hidden_size, dtype=dtype, device=device - ) - - # Ensure src2dst has no duplicate destinations to avoid race conditions - total_tokens = batch_size * topk - dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32) - src2dst = dst_indices.view(batch_size, topk) - - topk_ids = torch.randint( - start_expert_id, - end_expert_id + 1, - (batch_size, topk), - dtype=torch.int32, - device=device, - ) - - topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device) - - return down_output, src2dst, topk_ids, topk_weights - - -def run_cuda_kernel( - down_output: torch.Tensor, - output: torch.Tensor, - src2dst: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - start_expert_id: int, - end_expert_id: int, - topk: int, -): - ep_moe_post_reorder( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, - ) - return output - - -def run_triton_kernel( - down_output: torch.Tensor, - output: torch.Tensor, - src2dst: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - start_expert_id: int, - end_expert_id: int, - topk: int, - hidden_size: int, -): - batch_size = down_output.size(0) - block_size = 512 - - post_reorder_triton_kernel[(batch_size,)]( - down_output, - output, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, - hidden_size, - 0, - block_size, - ) - return output - - -def assert_close(a, b): - a32, b32 = a.float(), b.float() - if a.dtype is torch.float16: - torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-2) - elif a.dtype is torch.bfloat16: - torch.testing.assert_close(a32, b32, rtol=1e-4, atol=1e-1) - else: - torch.testing.assert_close(a32, b32, rtol=1e-5, atol=1e-5) - - -@pytest.mark.parametrize( - "batch_size,hidden_size,topk", - list(itertools.product([32, 64], [128, 256, 512], [2, 4, 8])), -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) -def test_ep_moe_post_reorder_vs_triton( - batch_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, -): - device = torch.device("cuda") - start_expert_id = 0 - end_expert_id = 15 - - ( - down_output, - src2dst, - topk_ids, - topk_weights, - ) = create_test_tensors( - batch_size, - hidden_size, - topk, - start_expert_id, - end_expert_id, - dtype, - device, - ) - - output_cuda = torch.empty(batch_size, hidden_size, dtype=dtype, device=device) - output_triton = torch.empty(batch_size, hidden_size, dtype=dtype, device=device) - - cuda_output = run_cuda_kernel( - down_output, - output_cuda, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, - ) - - triton_output = run_triton_kernel( - down_output, - output_triton, - src2dst, - topk_ids, - topk_weights, - start_expert_id, - end_expert_id, - topk, - hidden_size, - ) - - assert_close(cuda_output, triton_output) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py b/sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py deleted file mode 100644 index 718f633c9..000000000 --- a/sgl-kernel/tests/test_ep_moe_pre_reorder_kernel.py +++ /dev/null @@ -1,181 +0,0 @@ -import itertools - -import pytest -import torch -from sgl_kernel import ep_moe_pre_reorder - -from sglang.srt.layers.moe.ep_moe.kernels import pre_reorder_triton_kernel - - -def create_test_tensors( - batch_size: int, - hidden_size: int, - topk: int, - start_expert_id: int, - end_expert_id: int, - dtype: torch.dtype, - device: torch.device, - use_per_token_if_dynamic: bool = True, -): - input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) - - # Ensure src2dst has no duplicate destinations to avoid race conditions - total_tokens = batch_size * topk - dst_indices = torch.randperm(total_tokens, device=device, dtype=torch.int32) - src2dst = dst_indices.view(batch_size, topk) - - topk_ids = torch.randint( - start_expert_id, - end_expert_id + 1, - (batch_size, topk), - dtype=torch.int32, - device=device, - ) - - if use_per_token_if_dynamic: - a1_scales = ( - torch.rand(batch_size, dtype=torch.float32, device=device) * 0.8 + 0.2 - ) - else: - a1_scales = ( - torch.rand( - end_expert_id - start_expert_id + 1, dtype=torch.float32, device=device - ) - * 0.8 - + 0.2 - ) - - return input_tensor, src2dst, topk_ids, a1_scales - - -def run_cuda_kernel( - input_tensor: torch.Tensor, - gateup_input: torch.Tensor, - src2dst: torch.Tensor, - topk_ids: torch.Tensor, - a1_scales: torch.Tensor, - start_expert_id: int, - end_expert_id: int, - topk: int, - use_per_token_if_dynamic: bool, -): - ep_moe_pre_reorder( - input_tensor, - gateup_input, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - use_per_token_if_dynamic, - ) - return gateup_input - - -def run_triton_kernel( - input_tensor: torch.Tensor, - gateup_input: torch.Tensor, - src2dst: torch.Tensor, - topk_ids: torch.Tensor, - a1_scales: torch.Tensor, - start_expert_id: int, - end_expert_id: int, - topk: int, - hidden_size: int, - use_per_token_if_dynamic: bool, -): - batch_size = input_tensor.size(0) - block_size = 512 - - pre_reorder_triton_kernel[(batch_size,)]( - input_tensor, - gateup_input, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - hidden_size, - block_size, - use_per_token_if_dynamic, - ) - return gateup_input - - -@pytest.mark.parametrize( - "batch_size,hidden_size,topk", - list(itertools.product([32, 64, 128], [512, 1024, 2048], [4, 8])), -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) -@pytest.mark.parametrize("use_per_token_if_dynamic", [True, False]) -def test_ep_moe_pre_reorder_vs_triton( - batch_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_per_token_if_dynamic: bool, -): - device = torch.device("cuda") - start_expert_id = 0 - end_expert_id = 15 - - ( - input_tensor, - src2dst, - topk_ids, - a1_scales, - ) = create_test_tensors( - batch_size, - hidden_size, - topk, - start_expert_id, - end_expert_id, - dtype, - device, - use_per_token_if_dynamic, - ) - - gateup_input_cuda = torch.empty( - batch_size * topk, hidden_size, dtype=dtype, device=device - ) - gateup_input_triton = torch.empty( - batch_size * topk, hidden_size, dtype=dtype, device=device - ) - - cuda_output = run_cuda_kernel( - input_tensor, - gateup_input_cuda, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - use_per_token_if_dynamic, - ) - - triton_output = run_triton_kernel( - input_tensor, - gateup_input_triton, - src2dst, - topk_ids, - a1_scales, - start_expert_id, - end_expert_id, - topk, - hidden_size, - use_per_token_if_dynamic, - ) - - torch.testing.assert_close( - cuda_output.float(), - triton_output.float(), - rtol=1e-5, - atol=1e-5, - ) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py b/sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py deleted file mode 100644 index 7039c5086..000000000 --- a/sgl-kernel/tests/test_ep_moe_silu_and_mul_kernel.py +++ /dev/null @@ -1,142 +0,0 @@ -import itertools - -import pytest -import torch -from sgl_kernel import ep_moe_silu_and_mul - -from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_triton_kernel - - -def create_test_tensors( - total_tokens: int, - hidden_size: int, - start_expert_id: int, - end_expert_id: int, - dtype: torch.dtype, - device: torch.device, -): - gateup_output = torch.randn(total_tokens, hidden_size, dtype=dtype, device=device) - - reorder_topk_ids = torch.randint( - start_expert_id, - end_expert_id + 1, - (total_tokens,), - dtype=torch.int32, - device=device, - ) - - num_experts = end_expert_id - start_expert_id + 1 - scales = torch.rand(num_experts, dtype=torch.float32, device=device) * 0.8 + 0.5 - - half_hidden = hidden_size // 2 - down_input = torch.empty(total_tokens, half_hidden, dtype=dtype, device=device) - - return gateup_output, down_input, reorder_topk_ids, scales - - -def run_cuda_kernel( - gateup_output: torch.Tensor, - down_input: torch.Tensor, - reorder_topk_ids: torch.Tensor, - scales: torch.Tensor, - start_expert_id: int, - end_expert_id: int, -): - ep_moe_silu_and_mul( - gateup_output, - down_input, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - ) - return down_input - - -def run_triton_kernel( - gateup_output: torch.Tensor, - down_input: torch.Tensor, - reorder_topk_ids: torch.Tensor, - scales: torch.Tensor, - start_expert_id: int, - end_expert_id: int, - hidden_size: int, -): - total_tokens = gateup_output.size(0) - block_size = 512 - - silu_and_mul_triton_kernel[(total_tokens,)]( - gateup_output, - down_input, - hidden_size, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - block_size, - ) - return down_input - - -@pytest.mark.parametrize( - "total_tokens,hidden_size", - list(itertools.product([32, 256, 1024], [128, 256, 512])), -) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) -def test_ep_moe_silu_and_mul_vs_triton( - total_tokens: int, - hidden_size: int, - dtype: torch.dtype, -): - device = torch.device("cuda") - start_expert_id = 0 - end_expert_id = 15 - - ( - gateup_output, - _, - reorder_topk_ids, - scales, - ) = create_test_tensors( - total_tokens, - hidden_size, - start_expert_id, - end_expert_id, - dtype, - device, - ) - - down_input_cuda = torch.empty( - total_tokens, hidden_size // 2, dtype=dtype, device=device - ) - down_input_triton = torch.empty_like(down_input_cuda) - - cuda_output = run_cuda_kernel( - gateup_output, - down_input_cuda, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - ) - - triton_output = run_triton_kernel( - gateup_output, - down_input_triton, - reorder_topk_ids, - scales, - start_expert_id, - end_expert_id, - hidden_size, - ) - - torch.testing.assert_close( - cuda_output, - triton_output, - rtol=1e-5, - atol=1e-5, - ) - - -if __name__ == "__main__": - pytest.main([__file__])