Fix one missing arg in DeepEP (#6878)
This commit is contained in:
@@ -180,6 +180,9 @@ class EPMoE(torch.nn.Module):
|
||||
self.layer_id = layer_id
|
||||
self.num_experts = num_experts
|
||||
assert self.num_experts % self.tp_size == 0
|
||||
+ assert (
|
||||
+ num_fused_shared_experts == 0
|
||||
+ ), "num_fused_shared_experts is not supported in EP"
|
||||
self.num_experts_per_partition = self.num_experts // self.tp_size
|
||||
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
|
||||
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
|
||||
@@ -191,7 +194,6 @@ class EPMoE(torch.nn.Module):
|
||||
if self.use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
self.num_expert_group = num_expert_group
|
||||
- self.num_fused_shared_experts = num_fused_shared_experts
|
||||
self.topk_group = topk_group
|
||||
self.correction_bias = correction_bias
|
||||
self.custom_routing_function = custom_routing_function
|
||||
@@ -252,7 +254,6 @@ class EPMoE(torch.nn.Module):
|
||||
renormalize=self.renormalize,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
- num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
correction_bias=self.correction_bias,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
@@ -886,6 +887,7 @@ class DeepEPMoE(EPMoE):
|
||||
renormalize: bool = True,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
+ num_fused_shared_experts: int = 0,
|
||||
topk_group: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
@@ -897,23 +899,24 @@ class DeepEPMoE(EPMoE):
|
||||
deepep_mode: DeepEPMode = DeepEPMode.auto,
|
||||
):
|
||||
super().__init__(
|
||||
- num_experts,
|
||||
- top_k,
|
||||
- hidden_size,
|
||||
- intermediate_size,
|
||||
- layer_id,
|
||||
- params_dtype,
|
||||
- renormalize,
|
||||
- use_grouped_topk,
|
||||
- num_expert_group,
|
||||
- topk_group,
|
||||
- quant_config,
|
||||
- tp_size,
|
||||
- prefix,
|
||||
- correction_bias,
|
||||
- custom_routing_function,
|
||||
- activation,
|
||||
- routed_scaling_factor,
|
||||
+ num_experts=num_experts,
|
||||
+ top_k=top_k,
|
||||
+ hidden_size=hidden_size,
|
||||
+ intermediate_size=intermediate_size,
|
||||
+ layer_id=layer_id,
|
||||
+ params_dtype=params_dtype,
|
||||
+ renormalize=renormalize,
|
||||
+ use_grouped_topk=use_grouped_topk,
|
||||
+ num_expert_group=num_expert_group,
|
||||
+ num_fused_shared_experts=num_fused_shared_experts,
|
||||
+ topk_group=topk_group,
|
||||
+ quant_config=quant_config,
|
||||
+ tp_size=tp_size,
|
||||
+ prefix=prefix,
|
||||
+ correction_bias=correction_bias,
|
||||
+ custom_routing_function=custom_routing_function,
|
||||
+ activation=activation,
|
||||
+ routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
self.deepep_mode = deepep_mode
|
||||
if self.deepep_mode.enable_low_latency():
|
||||
|
||||
Reference in New Issue
Block a user