[torch.compile bug] avoid biased_grouped_topk_impl func repeatedly triggering torch.compile in forward pass (#8353)
This commit is contained in:
@@ -387,6 +387,7 @@ def grouped_topk_cpu(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
|
||||||
def biased_grouped_topk_impl(
|
def biased_grouped_topk_impl(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
@@ -482,7 +483,6 @@ def biased_grouped_topk_gpu(
|
|||||||
renormalize: bool,
|
renormalize: bool,
|
||||||
num_expert_group: int = 0,
|
num_expert_group: int = 0,
|
||||||
topk_group: int = 0,
|
topk_group: int = 0,
|
||||||
compiled: bool = not _is_npu,
|
|
||||||
num_fused_shared_experts: int = 0,
|
num_fused_shared_experts: int = 0,
|
||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
@@ -535,14 +535,7 @@ def biased_grouped_topk_gpu(
|
|||||||
)
|
)
|
||||||
return topk_weights, topk_ids
|
return topk_weights, topk_ids
|
||||||
else:
|
else:
|
||||||
biased_grouped_topk_fn = (
|
return biased_grouped_topk_impl(
|
||||||
torch.compile(
|
|
||||||
biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
|
|
||||||
)
|
|
||||||
if compiled
|
|
||||||
else biased_grouped_topk_impl
|
|
||||||
)
|
|
||||||
return biased_grouped_topk_fn(
|
|
||||||
hidden_states,
|
hidden_states,
|
||||||
gating_output,
|
gating_output,
|
||||||
correction_bias,
|
correction_bias,
|
||||||
|
|||||||
Reference in New Issue
Block a user