diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 14c07b642..f5dceac78 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -89,6 +89,23 @@ def fused_topk( ) del token_expert_indicies + return _fused_topk_postprocess( + topk_weights=topk_weights, + topk_ids=topk_ids, + renormalize=renormalize, + expert_location_dispatch_info=expert_location_dispatch_info, + num_token_non_padded=num_token_non_padded, + ) + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def _fused_topk_postprocess( + topk_weights, + topk_ids, + renormalize, + expert_location_dispatch_info, + num_token_non_padded, +): if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) @@ -313,7 +330,6 @@ def select_experts( num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ): - router_logits, correction_bias = ( expert_location_dispatch.transform_select_experts_inputs( router_logits=router_logits,