From 38076dea84253c9f1e1df3fe0fb889c585e2e128 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Tue, 15 Apr 2025 07:24:26 +0800 Subject: [PATCH] apply fused moe gate in ds v3/r1 (#5371) Co-authored-by: Yineng Zhang --- python/sglang/srt/layers/moe/topk.py | 53 +++++++++++++++++++--------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 53c36c63a..c12b9d019 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -12,6 +12,7 @@ # limitations under the License. # ============================================================================== +import math import os from typing import Callable, Optional @@ -25,6 +26,8 @@ from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip _is_cuda = is_cuda() _is_hip = is_hip() +if _is_cuda: + from sgl_kernel import moe_fused_gate expert_distribution_recorder = ExpertDistributionRecorder() @@ -209,6 +212,10 @@ def biased_grouped_topk_impl( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +def is_power_of_two(n): + return n > 0 and math.log2(n).is_integer() + + def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -220,23 +227,37 @@ def biased_grouped_topk( compiled: bool = True, n_share_experts_fusion: int = 0, ): - biased_grouped_topk_fn = ( - torch.compile( - biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend() + # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now. + if ( + _is_cuda + and n_share_experts_fusion == 0 + and is_power_of_two(correction_bias.shape[0]) + ): + return moe_fused_gate( + gating_output, + correction_bias, + num_expert_group, + topk_group, + topk, + ) + else: + biased_grouped_topk_fn = ( + torch.compile( + biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend() + ) + if compiled + else biased_grouped_topk_impl + ) + return biased_grouped_topk_fn( + hidden_states, + gating_output, + correction_bias, + topk, + renormalize, + num_expert_group, + topk_group, + n_share_experts_fusion=n_share_experts_fusion, ) - if compiled - else biased_grouped_topk_impl - ) - return biased_grouped_topk_fn( - hidden_states, - gating_output, - correction_bias, - topk, - renormalize, - num_expert_group, - topk_group, - n_share_experts_fusion=n_share_experts_fusion, - ) def select_experts(