Minor speedup topk postprocessing (#7058)
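Summary (inferred from the diff below): the `topk_ids` postprocessing of the `moe_fused_gate` branch moves into a module-level helper, `_biased_grouped_topk_postprocess`, decorated with `@torch.compile` once at import time. The logical-to-physical remap and the padded-region masking now run as a single compiled function instead of rebuilding a `torch.compile` wrapper inside the hot path, and the whole postprocess is skipped when both `expert_location_dispatch_info` and `num_token_non_padded` are `None`.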
@@ -249,6 +249,15 @@ def _mask_topk_ids_padded_region(
     topk_ids[indices >= num_token_non_padded, :] = -1
 
 
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _biased_grouped_topk_postprocess(
+    topk_ids, expert_location_dispatch_info, num_token_non_padded
+):
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_ids
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -282,14 +291,13 @@ def biased_grouped_topk(
             num_fused_shared_experts,
             routed_scaling_factor,
         )
-        # TODO merge into kernel for this branch
-        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
-        # TODO will fuse this into kernel, thus use slow manual operation now
-        if num_token_non_padded is None:
-            return topk_weights, topk_ids
-        torch.compile(
-            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
-        )(topk_ids, num_token_non_padded)
+        # TODO merge into kernel
+        if (expert_location_dispatch_info is not None) or (
+            num_token_non_padded is not None
+        ):
+            topk_ids = _biased_grouped_topk_postprocess(
+                topk_ids, expert_location_dispatch_info, num_token_non_padded
+            )
         return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
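For readers unfamiliar with the pattern, the sketch below contrasts the old and new call shapes. It is a minimal standalone illustration, assuming only a working PyTorch 2.x `torch.compile`: `_mask`, `postprocess_old`, and `postprocess_new` are hypothetical stand-ins for `_mask_topk_ids_padded_region` and the new helper, and the default backend replaces `get_compiler_backend()`.

import torch

def _mask(ids: torch.Tensor, limit: torch.Tensor) -> None:
    # Stand-in for _mask_topk_ids_padded_region: invalidate rows at or
    # beyond `limit` (the padded region) by writing -1.
    indices = torch.arange(ids.shape[0], device=ids.device)
    ids[indices >= limit, :] = -1

# Old shape: torch.compile(...) is invoked inside the hot path, so a fresh
# compiled wrapper is constructed on every call; the main saving of the
# change is avoiding that per-call Python overhead.
def postprocess_old(ids: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
    torch.compile(_mask, dynamic=True)(ids, limit)
    return ids

# New shape: decorate once at import time; every call reuses the same
# compiled callable, and all postprocessing steps share one compiled region.
@torch.compile(dynamic=True)
def postprocess_new(ids: torch.Tensor, limit: torch.Tensor) -> torch.Tensor:
    _mask(ids, limit)
    return ids

ids = torch.randint(0, 8, (6, 2))
print(postprocess_new(ids, torch.tensor(4)))  # rows 4 and 5 become -1

In the actual change, the helper is additionally guarded so it only runs when at least one of `expert_location_dispatch_info` and `num_token_non_padded` is set, avoiding the compiled call entirely when no postprocessing is needed.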