Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)
@@ -303,6 +303,7 @@ def select_experts(
     renormalize: bool,
     topk_group: Optional[int] = None,
     num_expert_group: Optional[int] = None,
+    num_fused_shared_experts: int = 0,
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
@@ -310,7 +311,6 @@ def select_experts(
     num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
-    num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
 
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
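The hunks add num_fused_shared_experts as an explicit keyword parameter of select_experts and drop the lookup from global_server_args_dict inside the function body. As a rough illustration of the behavior described by the commit title, the caller-side value could be resolved along these lines; the helper name resolve_num_fused_shared_experts and the flag disable_shared_experts_fusion are assumptions for this sketch, not names taken from the diff:

    # Hedged sketch only, not the actual sglang implementation.
    def resolve_num_fused_shared_experts(
        num_shared_experts: int,
        disable_shared_experts_fusion: bool = False,
    ) -> int:
        # Per the commit title: use the model's shared-expert count unless
        # shared-experts fusion is disabled, in which case nothing is fused.
        return 0 if disable_shared_experts_fusion else num_shared_experts

    # Callers would then pass the value explicitly instead of relying on the
    # global_server_args_dict lookup inside select_experts, e.g.:
    #     select_experts(..., num_fused_shared_experts=resolve_num_fused_shared_experts(2))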