From a9dd3ec3e961e9fa9fb666c294c496079bacc156 Mon Sep 17 00:00:00 2001 From: erictanjn <142883585+erictanjn@users.noreply.github.com> Date: Mon, 28 Jul 2025 20:36:46 +0800 Subject: [PATCH] fix:reorder topk experts to ensure shared expert replaces minimal score (#8125) --- python/sglang/srt/layers/moe/topk.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 475066a1c..f2365d70e 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -397,7 +397,9 @@ def grouped_topk_gpu( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0 + ) if num_fused_shared_experts: topk_ids[:, -1] = torch.randint( low=num_experts, @@ -486,7 +488,9 @@ def biased_grouped_topk_impl( tmp_scores = scores_for_choice.masked_fill( ~score_mask.bool(), float("-inf") ) # [n, e] - _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + _, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0 + ) topk_weights = scores.gather(1, topk_ids) if num_fused_shared_experts: