From 0242bb9c7437d7b597d8145b5db61f888614e5f9 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 4 Aug 2025 01:45:15 +0800
Subject: [PATCH] Fix triton kernels topk with keyword arguments (#8732)

Move the Triton-kernels branch of TopK into a triton_kernels_topk()
helper that calls routing() with keyword arguments only. The previous
call passed self.renormalize as the third positional argument.

Add an sm_first parameter (default False) that is forwarded to
routing(), and reject renormalize=True with a ValueError because the
helper does not support it.

---
 python/sglang/srt/layers/moe/topk.py | 41 ++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index f2365d70e..b372858f7 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -183,12 +183,15 @@ class TopK(CustomOp):
         *,
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+        sm_first: bool = False,  # only used for triton kernels topk
     ) -> TopKOutput:
         if self.use_triton_kernels:
-            routing_data, gather_idx, scatter_idx = routing(
-                router_logits, self.top_k, self.renormalize
+            return triton_kernels_topk(
+                router_logits=router_logits,
+                topk=self.top_k,
+                renormalize=self.renormalize,
+                sm_first=sm_first,
             )
-            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
         else:
             torch_native = False
             return select_experts(
@@ -644,6 +647,38 @@ def biased_grouped_topk_cpu(
     )
 
 
+def triton_kernels_topk(
+    router_logits: torch.Tensor,
+    topk: int,
+    renormalize: bool = False,
+    sm_first: bool = False,
+) -> TritonKernelTopKOutput:
+    """Top-K routing for Triton kernels MoE.
+
+    Args:
+        router_logits: Router logits, passed to ``routing`` as ``logits``.
+        topk: Passed to ``routing`` as ``n_expts_act``.
+        renormalize: Must be False; renormalization is not supported.
+        sm_first: Forwarded unchanged to ``routing``.
+
+    Returns:
+        The three ``routing`` results wrapped in a TritonKernelTopKOutput.
+
+    Raises:
+        ValueError: If ``renormalize`` is True.
+    """
+    # Raise explicitly: an assert would be stripped under ``python -O``.
+    if renormalize:
+        raise ValueError("Triton kernels topk doesn't support renormalize")
+    # Keyword arguments keep each value bound to the intended parameter.
+    routing_data, gather_idx, scatter_idx = routing(
+        logits=router_logits,
+        n_expts_act=topk,
+        sm_first=sm_first,
+    )
+    return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+
+
 if _is_cpu and _is_cpu_amx_available:
     biased_grouped_topk = biased_grouped_topk_cpu
     grouped_topk = grouped_topk_cpu