Fix top-k inference performance regression (#6474)
This commit is contained in:
@@ -264,6 +264,8 @@ def biased_grouped_topk(
|
|||||||
# TODO merge into kernel for this branch
|
# TODO merge into kernel for this branch
|
||||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||||
# TODO will fuse this into kernel, thus use slow manual operation now
|
# TODO will fuse this into kernel, thus use slow manual operation now
|
||||||
|
if num_token_non_padded is None:
|
||||||
|
return topk_weights, topk_ids
|
||||||
torch.compile(
|
torch.compile(
|
||||||
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
|
_mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
|
||||||
)(topk_ids, num_token_non_padded)
|
)(topk_ids, num_token_non_padded)
|
||||||
|
|||||||
Reference in New Issue
Block a user