diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index b372858f7..c346e12f7 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -183,15 +183,13 @@ class TopK(CustomOp): *, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, - sm_first: bool = False, # only used for triton kernels topk ) -> TopKOutput: if self.use_triton_kernels: - return triton_kernels_topk( - router_logits=router_logits, - topk=self.top_k, - renormalize=self.renormalize, - sm_first=sm_first, + # renormalize=True is equivalent to sm_first=False + routing_data, gather_idx, scatter_idx = routing( + router_logits, self.top_k, sm_first=not self.renormalize ) + return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx) else: torch_native = False return select_experts( @@ -647,22 +645,6 @@ def biased_grouped_topk_cpu( ) -def triton_kernels_topk( - router_logits: torch.Tensor, - topk: int, - renormalize: bool = False, - sm_first: bool = False, -) -> TritonKernelTopKOutput: - """Top-K routing for Triton kernels MoE.""" - assert not renormalize, "Triton kernels topk doesn't support renormalize" - routing_data, gather_idx, scatter_idx = routing( - logits=router_logits, - n_expts_act=topk, - sm_first=sm_first, - ) - return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx) - - if _is_cpu and _is_cpu_amx_available: biased_grouped_topk = biased_grouped_topk_cpu grouped_topk = grouped_topk_cpu