Fix biased_grouped_topk_cpu (#9420)
This commit is contained in:
@@ -709,8 +709,10 @@ def biased_grouped_topk_cpu(
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||
):
|
||||
assert expert_location_dispatch_info is None
|
||||
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||
return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
|
||||
Reference in New Issue
Block a user