Minor compile fused topk (#6944)
This commit is contained in:
@@ -89,6 +89,23 @@ def fused_topk(
|
|||||||
)
|
)
|
||||||
del token_expert_indicies
|
del token_expert_indicies
|
||||||
|
|
||||||
|
return _fused_topk_postprocess(
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
renormalize=renormalize,
|
||||||
|
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||||
|
num_token_non_padded=num_token_non_padded,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
|
def _fused_topk_postprocess(
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
renormalize,
|
||||||
|
expert_location_dispatch_info,
|
||||||
|
num_token_non_padded,
|
||||||
|
):
|
||||||
if renormalize:
|
if renormalize:
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||||
@@ -313,7 +330,6 @@ def select_experts(
|
|||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
router_logits, correction_bias = (
|
router_logits, correction_bias = (
|
||||||
expert_location_dispatch.transform_select_experts_inputs(
|
expert_location_dispatch.transform_select_experts_inputs(
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
|
|||||||
Reference in New Issue
Block a user