Remove type conversion and fix id map in topk (#7759)
This commit is contained in:
@@ -112,10 +112,11 @@ def fused_topk(
|
||||
topk_softmax(
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
gating_output.float(),
|
||||
gating_output,
|
||||
renormalize,
|
||||
)
|
||||
|
||||
topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
|
||||
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user