diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index a1c36fff0..c0ab4e0e2 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -146,7 +146,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.def( - "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> " + "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int " + "n_share_experts_fusion, float routed_scaling_factor) -> " "(Tensor[])"); m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate); diff --git a/sgl-kernel/csrc/moe/moe_fused_gate.cu b/sgl-kernel/csrc/moe/moe_fused_gate.cu index c8aa4811f..6f0c474ba 100644 --- a/sgl-kernel/csrc/moe/moe_fused_gate.cu +++ b/sgl-kernel/csrc/moe/moe_fused_gate.cu @@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl( int64_t num_rows, int64_t topk_group, int64_t topk, + int64_t n_share_experts_fusion, + double routed_scaling_factor, Params params) { int tidx = threadIdx.x; int64_t thread_row = @@ -65,6 +67,9 @@ __device__ void moe_fused_gate_impl( return; } + // Calculate topk_excluding_share_expert_fusion from topk + int64_t topk_excluding_share_expert_fusion = topk - (n_share_experts_fusion > 0 ? 
1 : 0); + // Cast pointers to type T: auto* input_ptr = reinterpret_cast(input); auto* bias_ptr = reinterpret_cast(bias); @@ -163,7 +168,7 @@ __device__ void moe_fused_gate_impl( ////////////////////// Topk ////////////////////// float output_sum = 0.0f; - for (int k_idx = 0; k_idx < topk; ++k_idx) { + for (int k_idx = 0; k_idx < topk_excluding_share_expert_fusion; ++k_idx) { // local argmax T max_val = bias_chunk[0]; int expert = first_elt_read_by_thread; @@ -181,7 +186,7 @@ __device__ void moe_fused_gate_impl( max_val = static_cast(-FLT_MAX); } -// argmax reduce + // argmax reduce #pragma unroll for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) { T other_max = @@ -195,36 +200,46 @@ __device__ void moe_fused_gate_impl( } } - if (k_idx < topk) { - int thread_to_clear_in_group = expert / params.VPT; - int64_t idx = topk * thread_row + k_idx; + int thread_to_clear_in_group = expert / params.VPT; + int64_t idx = topk * thread_row + k_idx; - if (thread_group_idx == thread_to_clear_in_group) { - int expert_to_clear_in_thread = expert % params.VPT; + if (thread_group_idx == thread_to_clear_in_group) { + int expert_to_clear_in_thread = expert % params.VPT; - // clear the max value in the thread - bias_chunk[expert_to_clear_in_thread] = static_cast(-FLT_MAX); + // clear the max value in the thread + bias_chunk[expert_to_clear_in_thread] = static_cast(-FLT_MAX); - // store output - output_ptr[idx] = static_cast(row_chunk[expert_to_clear_in_thread]); - indices_ptr[idx] = static_cast(expert); - } + // store output + output_ptr[idx] = static_cast(row_chunk[expert_to_clear_in_thread]); + indices_ptr[idx] = static_cast(expert); + } - // accumulate sum - if (thread_group_idx == 0) { - output_sum += output_ptr[idx]; - } + // accumulate sum for all elements + if (thread_group_idx == 0) { + output_sum += output_ptr[idx]; } __syncthreads(); } + if (thread_group_idx == 0 && n_share_experts_fusion > 0) { + int64_t last_idx = topk * thread_row + 
topk_excluding_share_expert_fusion; + + // Use round-robin to select expert + int64_t expert_offset = thread_row % n_share_experts_fusion; + indices_ptr[last_idx] = static_cast(params.NUM_EXPERTS + expert_offset); + + // Set the weight to the sum of all weights divided by routed_scaling_factor + output_ptr[last_idx] = output_sum / routed_scaling_factor; + } + __syncthreads(); + ////////////////////// Rescale Output ////////////////////// if (thread_group_idx == 0) { #pragma unroll for (int ii = 0; ii < topk; ++ii) { int64_t const idx = topk * thread_row + ii; - output_ptr[idx] = static_cast(static_cast(output_ptr[idx]) / static_cast(output_sum)); + output_ptr[idx] = output_ptr[idx] / output_sum; } } } @@ -257,9 +272,21 @@ __global__ void moe_fused_gate_kernel( int32_t* indices_ptr, int64_t num_rows, int64_t topk_group, - int64_t topk) { + int64_t topk, + int64_t n_share_experts_fusion, + double routed_scaling_factor) { KernelParams params; - moe_fused_gate_impl(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params); + moe_fused_gate_impl( + input, + bias, + output_ptr, + indices_ptr, + num_rows, + topk_group, + topk, + n_share_experts_fusion, + routed_scaling_factor, + params); } // Macro to compute compile-time constants and launch the kernel. 
@@ -277,7 +304,9 @@ __global__ void moe_fused_gate_kernel( indices.data_ptr(), \ num_rows, \ topk_group, \ - topk); \ + topk, \ + n_share_experts_fusion, \ + routed_scaling_factor); \ dispatched = true; \ } while (0) @@ -303,7 +332,9 @@ __global__ void moe_fused_gate_kernel_dynamic( int64_t num_experts, int64_t num_expert_group, int64_t topk_group, - int64_t topk) { + int64_t topk, + int64_t n_share_experts_fusion, + double routed_scaling_factor) { KernelParamsDynamic params; params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256 params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32 @@ -312,14 +343,30 @@ __global__ void moe_fused_gate_kernel_dynamic( params.ROWS_PER_WARP = std::max(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32 params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP; - moe_fused_gate_impl(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params); + moe_fused_gate_impl( + input, + bias, + output_ptr, + indices_ptr, + num_rows, + topk_group, + topk, + n_share_experts_fusion, + routed_scaling_factor, + params); } //------------------------------------------------------------------------------ // Host Launcher Function //------------------------------------------------------------------------------ -std::vector -moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) { +std::vector moe_fused_gate( + at::Tensor& input, + at::Tensor& bias, + int64_t num_expert_group, + int64_t topk_group, + int64_t topk, + int64_t n_share_experts_fusion, + double routed_scaling_factor) { int64_t num_rows = input.size(0); int32_t num_experts = input.size(1); auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); @@ -416,7 +463,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in num_experts, num_expert_group, topk_group, - topk); + topk, + n_share_experts_fusion, 
+ routed_scaling_factor); } else if (input.scalar_type() == at::kHalf) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -427,7 +476,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in num_experts, num_expert_group, topk_group, - topk); + topk, + n_share_experts_fusion, + routed_scaling_factor); } else if (input.scalar_type() == at::kFloat) { moe_fused_gate_kernel_dynamic<<>>( input.data_ptr(), @@ -438,7 +489,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in num_experts, num_expert_group, topk_group, - topk); + topk, + n_share_experts_fusion, + routed_scaling_factor); } else { TORCH_CHECK(false, "Unsupported data type for moe_fused_gate"); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 8a6d1c44b..10df9d1c7 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -200,8 +200,14 @@ void topk_softmax( torch::Tensor& token_expert_indices, torch::Tensor& gating_output); -std::vector -moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk); +std::vector moe_fused_gate( + at::Tensor& input, + at::Tensor& bias, + int64_t num_expert_group, + int64_t topk_group, + int64_t topk, + int64_t n_share_experts_fusion, + double routed_scaling_factor); /* * From csrc/speculative diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py index 1067a1760..afabc44f9 100644 --- a/sgl-kernel/python/sgl_kernel/moe.py +++ b/sgl-kernel/python/sgl_kernel/moe.py @@ -34,13 +34,29 @@ def topk_softmax( ) -def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk): +def moe_fused_gate( + input_tensor, + bias, + num_expert_group, + topk_group, + topk, + n_share_experts_fusion=0, + routed_scaling_factor=0, +): # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion # it split group of expert into 
num_expert_group, and use top2 expert weight sum in each group # as the group weight to select exerpt groups and then select topk experts within the selected groups # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now. # for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk + # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert + # routed_scaling_factor: divisor for the shared expert's weight (sum of routed weights / this factor); must be nonzero when n_share_experts_fusion > 0 return torch.ops.sgl_kernel.moe_fused_gate.default( - input_tensor, bias, num_expert_group, topk_group, topk + input_tensor, + bias, + num_expert_group, + topk_group, + topk, + n_share_experts_fusion, + routed_scaling_factor, ) diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py index 3d6221bf4..82404e572 100644 --- a/sgl-kernel/tests/test_moe_fused_gate.py +++ b/sgl-kernel/tests/test_moe_fused_gate.py @@ -19,13 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk (512, 16, 8, 16), ], ) -def test_moe_fused_gate_combined(seq_length, dtype, params): +@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16]) +def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion): num_experts, num_expert_group, topk_group, topk = params torch.manual_seed(seq_length) tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda() scores = tensor.clone() bias = torch.rand(num_experts).to(dtype).cuda() + topk = topk + min(1, n_share_experts_fusion) output, indices = moe_fused_gate( tensor, bias, num_expert_group=num_expert_group, topk_group=topk_group, topk=topk, + n_share_experts_fusion=n_share_experts_fusion, + routed_scaling_factor=2.5, ) ref_output, ref_indices =
biased_grouped_topk( scores, @@ -43,8 +47,30 @@ def test_moe_fused_gate_combined(seq_length, dtype, params): num_expert_group=num_expert_group, topk_group=topk_group, compiled=False, + n_share_experts_fusion=n_share_experts_fusion, ) + # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension + if n_share_experts_fusion > 0: + original_indices = indices.clone() + original_ref_indices = ref_indices.clone() + + indices = indices[:, :-1] + ref_indices = ref_indices[:, :-1] + + valid_min = num_experts + valid_max = num_experts + n_share_experts_fusion + shared_indices = original_indices[:, -1] + shared_ref_indices = original_ref_indices[:, -1] + if shared_indices is not None: + assert torch.all( + (shared_indices >= valid_min) & (shared_indices < valid_max) + ), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})" + if shared_ref_indices is not None: + assert torch.all( + (shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max) + ), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})" + idx_check = torch.allclose( ref_indices.sort()[0].to(torch.int32), indices.sort()[0].to(torch.int32), @@ -54,17 +80,17 @@ def test_moe_fused_gate_combined(seq_length, dtype, params): output_check = torch.allclose( ref_output.sort()[0].to(torch.float32), output.sort()[0].to(torch.float32), - rtol=1e-04, - atol=1e-05, + rtol=1e-02, + atol=1e-03, ) assert idx_check, ( f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, " - f"params {params}" + f"params {params}, n_share_experts_fusion {n_share_experts_fusion}" ) assert output_check, ( f"Output mismatch at seq_length {seq_length}, dtype {dtype}, " - f"params {params}" + f"params {params}, n_share_experts_fusion {n_share_experts_fusion}" )