Set num_fused_shared_experts as num_shared_experts when shared_experts fusion is not disabled (#6736)
@@ -156,6 +156,7 @@ class EPMoE(torch.nn.Module):
         renormalize: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -190,6 +191,7 @@ class EPMoE(torch.nn.Module):
         if self.use_grouped_topk:
             assert num_expert_group is not None and topk_group is not None
         self.num_expert_group = num_expert_group
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.topk_group = topk_group
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
@@ -250,6 +252,7 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
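The three hunks add the new constructor parameter with a default of 0, store it on the instance, and forward it to the expert-selection call, so existing callers are unaffected. Per the commit title, callers are now expected to pass the model's shared-expert count whenever shared-experts fusion is not disabled. A minimal caller-side sketch of that rule follows; the helper name and both argument names are hypothetical illustrations, not taken from this diff:

    def resolve_num_fused_shared_experts(
        n_shared_experts: int, disable_shared_experts_fusion: bool
    ) -> int:
        # Hypothetical helper mirroring the commit title: use the model's
        # shared-expert count unless fusion is explicitly disabled.
        return 0 if disable_shared_experts_fusion else n_shared_experts

    # e.g. a DeepSeek-style config with 2 shared experts:
    assert resolve_num_fused_shared_experts(2, False) == 2  # fusion enabled
    assert resolve_num_fused_shared_experts(2, True) == 0   # fusion disabled

Keeping the default at 0 means EPMoE behaves exactly as before for models that do not fuse shared experts into the routed experts.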