From b102353f8f2d464de6d2796d62e87878513ccdf6 Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Sun, 3 Aug 2025 17:03:04 -0700
Subject: [PATCH] [MoE] Enable `renormalize=False` in Triton kernels (#8735)

---
 python/sglang/srt/layers/moe/topk.py | 26 ++++----------------------
 1 file changed, 4 insertions(+), 22 deletions(-)

diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index b372858f7..c346e12f7 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -183,15 +183,13 @@ class TopK(CustomOp):
         *,
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-        sm_first: bool = False,  # only used for triton kernels topk
     ) -> TopKOutput:
         if self.use_triton_kernels:
-            return triton_kernels_topk(
-                router_logits=router_logits,
-                topk=self.top_k,
-                renormalize=self.renormalize,
-                sm_first=sm_first,
+            # renormalize=True is equivalent to sm_first=False
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, sm_first=not self.renormalize
             )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
             torch_native = False
             return select_experts(
@@ -647,22 +645,6 @@ def biased_grouped_topk_cpu(
     )
 
 
-def triton_kernels_topk(
-    router_logits: torch.Tensor,
-    topk: int,
-    renormalize: bool = False,
-    sm_first: bool = False,
-) -> TritonKernelTopKOutput:
-    """Top-K routing for Triton kernels MoE."""
-    assert not renormalize, "Triton kernels topk doesn't support renormalize"
-    routing_data, gather_idx, scatter_idx = routing(
-        logits=router_logits,
-        n_expts_act=topk,
-        sm_first=sm_first,
-    )
-    return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-
-
 if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu