From 0c227ee373acb4ccf220d46a2fb1c89c65bd8339 Mon Sep 17 00:00:00 2001 From: zixuanzhang226 Date: Fri, 21 Feb 2025 00:30:15 -0800 Subject: [PATCH] feat: update grouped_topk to support softmax and sigmoid (#3680) --- python/sglang/srt/layers/moe/topk.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 91ca00c6e..e808a0a20 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -75,7 +75,6 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek V2/V3/R1 series models @torch.compile(dynamic=True, backend=get_compiler_backend()) def grouped_topk( hidden_states: torch.Tensor, @@ -84,10 +83,17 @@ def grouped_topk( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, + scoring_func: str = "softmax", ): assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Scoring function '{scoring_func}' is not supported.") + num_token = scores.shape[0] group_scores = ( scores.view(num_token, num_expert_group, -1).max(dim=-1).values @@ -111,6 +117,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +# DeepSeek V2/V3/R1 uses biased_grouped_topk @torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk( hidden_states: torch.Tensor, @@ -165,7 +172,7 @@ def select_experts( correction_bias: Optional[torch.Tensor] = None, torch_native: bool = False, ): - # DeekSeekv2 uses grouped_top_k + # DeepSeek V2/V3/R1 uses biased_grouped_topk if use_grouped_topk: assert topk_group is not None assert num_expert_group is not None