diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index f5dceac78..0c3d92b66 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region( topk_ids[indices >= num_token_non_padded, :] = -1 +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def _biased_grouped_topk_postprocess( + topk_ids, expert_location_dispatch_info, num_token_non_padded +): + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_ids + + def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -282,14 +291,13 @@ def biased_grouped_topk( num_fused_shared_experts, routed_scaling_factor, ) - # TODO merge into kernel for this branch - topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) - # TODO will fuse this into kernel, thus use slow manual operation now - if num_token_non_padded is None: - return topk_weights, topk_ids - torch.compile( - _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend() - )(topk_ids, num_token_non_padded) + # TODO merge into kernel + if (expert_location_dispatch_info is not None) or ( + num_token_non_padded is not None + ): + topk_ids = _biased_grouped_topk_postprocess( + topk_ids, expert_location_dispatch_info, num_token_non_padded + ) return topk_weights, topk_ids else: biased_grouped_topk_fn = (