From 9045cc1eb8daa77e6d4d271e3bdebc6e26584303 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Fri, 25 Jul 2025 21:17:47 +0800 Subject: [PATCH] [torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering `torch.compile` in forward pass (#8353) --- docs/references/hardware.rst | 2 +- python/sglang/srt/layers/moe/topk.py | 11 ++--------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/docs/references/hardware.rst b/docs/references/hardware.rst index ea37b2b49..5be98e7cd 100644 --- a/docs/references/hardware.rst +++ b/docs/references/hardware.rst @@ -5,4 +5,4 @@ Hardware Supports amd.md nvidia_jetson.md - cpu.md \ No newline at end of file + cpu.md diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index a806a4052..ce00fb9c8 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -387,6 +387,7 @@ def grouped_topk_cpu( ) +@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu) def biased_grouped_topk_impl( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -482,7 +483,6 @@ def biased_grouped_topk_gpu( renormalize: bool, num_expert_group: int = 0, topk_group: int = 0, - compiled: bool = not _is_npu, num_fused_shared_experts: int = 0, routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, @@ -535,14 +535,7 @@ def biased_grouped_topk_gpu( ) return topk_weights, topk_ids else: - biased_grouped_topk_fn = ( - torch.compile( - biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend() - ) - if compiled - else biased_grouped_topk_impl - ) - return biased_grouped_topk_fn( + return biased_grouped_topk_impl( hidden_states, gating_output, correction_bias,