Make torch compile configurable for biased_grouped_topk (#4749)

This commit is contained in:
Qingquan Song
2025-03-28 10:57:52 -07:00
committed by GitHub
parent 4db29e82ec
commit 044c315970

View File

@@ -129,8 +129,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk(
def biased_grouped_topk_impl(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
@@ -171,6 +170,34 @@ def biased_grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def biased_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    compiled: bool = True,
):
    """Grouped top-k expert selection with a correction bias.

    Thin dispatcher around ``biased_grouped_topk_impl`` that optionally runs
    the implementation through ``torch.compile``.

    Args:
        hidden_states: Input activations for the routed tokens.
        gating_output: Router logits used to score the experts.
        correction_bias: Bias applied to expert scores before selection.
        topk: Number of experts selected per token.
        renormalize: Whether to renormalize the selected expert weights.
        num_expert_group: Number of expert groups (0 disables grouping).
        topk_group: Number of groups kept before the per-group top-k.
        compiled: When True, route through a torch.compile'd implementation.

    Returns:
        The result of ``biased_grouped_topk_impl`` (top-k weights and ids).
    """
    if compiled:
        # Compile once and memoize on the function object. The previous code
        # re-invoked torch.compile on every call, constructing a fresh wrapper
        # each time even though Dynamo caches the compiled artifact itself.
        fn = getattr(biased_grouped_topk, "_compiled_impl", None)
        if fn is None:
            fn = torch.compile(
                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
            )
            biased_grouped_topk._compiled_impl = fn
    else:
        fn = biased_grouped_topk_impl
    return fn(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
    )
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,