diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 475066a1c..f2365d70e 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -397,7 +397,9 @@ def grouped_topk_gpu( .reshape(num_token, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + topk_weights, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0 + ) if num_fused_shared_experts: topk_ids[:, -1] = torch.randint( low=num_experts, @@ -486,7 +488,9 @@ def biased_grouped_topk_impl( tmp_scores = scores_for_choice.masked_fill( ~score_mask.bool(), float("-inf") ) # [n, e] - _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + _, topk_ids = torch.topk( + tmp_scores, k=topk, dim=-1, sorted=num_fused_shared_experts > 0 + ) topk_weights = scores.gather(1, topk_ids) if num_fused_shared_experts: