From 044c3159706c190191205ea91bb9ebd2ba292bf4 Mon Sep 17 00:00:00 2001 From: Qingquan Song Date: Fri, 28 Mar 2025 10:57:52 -0700 Subject: [PATCH] Make torch compile configurable for biased_grouped_topk (#4749) --- python/sglang/srt/layers/moe/topk.py | 31 ++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index cfa264a23..29984f3f2 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -129,8 +129,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) -@torch.compile(dynamic=True, backend=get_compiler_backend()) -def biased_grouped_topk( +def biased_grouped_topk_impl( hidden_states: torch.Tensor, gating_output: torch.Tensor, correction_bias: torch.Tensor, @@ -171,6 +170,34 @@ def biased_grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +def biased_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, + compiled: bool = True, +): + biased_grouped_topk_fn = ( + torch.compile( + biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend() + ) + if compiled + else biased_grouped_topk_impl + ) + return biased_grouped_topk_fn( + hidden_states, + gating_output, + correction_bias, + topk, + renormalize, + num_expert_group, + topk_group, + ) + + def select_experts( hidden_states: torch.Tensor, router_logits: torch.Tensor,