[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering torch.compile in forward pass (#8353)

This commit is contained in:
Xiaoyu Zhang
2025-07-25 21:17:47 +08:00
committed by GitHub
parent 70e37b97bf
commit 9045cc1eb8
2 changed files with 3 additions and 10 deletions

View File

@@ -387,6 +387,7 @@ def grouped_topk_cpu(
)
+ @torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -482,7 +483,6 @@ def biased_grouped_topk_gpu(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
- compiled: bool = not _is_npu,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
@@ -535,14 +535,7 @@ def biased_grouped_topk_gpu(
)
return topk_weights, topk_ids
else:
- biased_grouped_topk_fn = (
- torch.compile(
- biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
- )
- if compiled
- else biased_grouped_topk_impl
- )
- return biased_grouped_topk_fn(
+ return biased_grouped_topk_impl(
hidden_states,
gating_output,
correction_bias,