Fix biased_grouped_topk_cpu (#9420)
This commit is contained in:
@@ -709,8 +709,10 @@ def biased_grouped_topk_cpu(
|
|||||||
routed_scaling_factor: Optional[float] = None,
|
routed_scaling_factor: Optional[float] = None,
|
||||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||||
|
apply_routed_scaling_factor_on_output: Optional[bool] = False,
|
||||||
):
|
):
|
||||||
assert expert_location_dispatch_info is None
|
assert expert_location_dispatch_info is None
|
||||||
|
assert not apply_routed_scaling_factor_on_output, "Not implemented"
|
||||||
return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
|
return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
gating_output,
|
gating_output,
|
||||||
|
|||||||
Reference in New Issue
Block a user