Make torch compile configurable for biased_grouped_topk (#4749)
This commit is contained in:
@@ -129,8 +129,7 @@ def grouped_topk(
|
|||||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
def biased_grouped_topk_impl(
|
||||||
def biased_grouped_topk(
|
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
gating_output: torch.Tensor,
|
gating_output: torch.Tensor,
|
||||||
correction_bias: torch.Tensor,
|
correction_bias: torch.Tensor,
|
||||||
@@ -171,6 +170,34 @@ def biased_grouped_topk(
|
|||||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
# torch.compile wrappers around biased_grouped_topk_impl, keyed by the backend
# returned from get_compiler_backend(). Filled lazily on first compiled call.
_BIASED_GROUPED_TOPK_COMPILED_BY_BACKEND = {}


def _get_compiled_biased_grouped_topk_impl():
    """Return a cached torch.compile wrapper of biased_grouped_topk_impl."""
    backend = get_compiler_backend()
    compiled_fn = _BIASED_GROUPED_TOPK_COMPILED_BY_BACKEND.get(backend)
    if compiled_fn is None:
        # Same torch.compile arguments as before; only the wrapper construction
        # is hoisted so it happens once per backend rather than once per call.
        compiled_fn = torch.compile(
            biased_grouped_topk_impl, dynamic=True, backend=backend
        )
        _BIASED_GROUPED_TOPK_COMPILED_BY_BACKEND[backend] = compiled_fn
    return compiled_fn


def biased_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    compiled: bool = True,
):
    """Dispatch to biased_grouped_topk_impl, optionally through torch.compile.

    All arguments except ``compiled`` are forwarded positionally and unchanged
    to ``biased_grouped_topk_impl``; their meaning is defined there.

    Args:
        hidden_states: Forwarded to the implementation.
        gating_output: Forwarded to the implementation.
        correction_bias: Forwarded to the implementation.
        topk: Forwarded to the implementation.
        renormalize: Forwarded to the implementation.
        num_expert_group: Forwarded to the implementation (default 0).
        topk_group: Forwarded to the implementation (default 0).
        compiled: When True (the default), call the implementation through a
            ``torch.compile`` wrapper built with ``dynamic=True`` and the
            backend from ``get_compiler_backend()``. When False, call the
            plain Python implementation directly.

    Returns:
        Whatever ``biased_grouped_topk_impl`` returns for these arguments.
    """
    if compiled:
        # Reuse one wrapper per backend instead of calling torch.compile on
        # every invocation.
        biased_grouped_topk_fn = _get_compiled_biased_grouped_topk_impl()
    else:
        biased_grouped_topk_fn = biased_grouped_topk_impl
    return biased_grouped_topk_fn(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
    )
|
||||||
|
|
||||||
|
|
||||||
def select_experts(
|
def select_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user