Fix one missing arg in DeepEP (#6878)
This commit is contained in:
@@ -180,6 +180,9 @@ class EPMoE(torch.nn.Module):
 self.layer_id = layer_id
 self.num_experts = num_experts
 assert self.num_experts % self.tp_size == 0
+assert (
+num_fused_shared_experts == 0
+), "num_fused_shared_experts is not supported in EP"
 self.num_experts_per_partition = self.num_experts // self.tp_size
 self.start_expert_id = self.tp_rank * self.num_experts_per_partition
 self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -191,7 +194,6 @@ class EPMoE(torch.nn.Module):
 if self.use_grouped_topk:
 assert num_expert_group is not None and topk_group is not None
 self.num_expert_group = num_expert_group
-self.num_fused_shared_experts = num_fused_shared_experts
 self.topk_group = topk_group
 self.correction_bias = correction_bias
 self.custom_routing_function = custom_routing_function
@@ -252,7 +254,6 @@ class EPMoE(torch.nn.Module):
 renormalize=self.renormalize,
 topk_group=self.topk_group,
 num_expert_group=self.num_expert_group,
-num_fused_shared_experts=self.num_fused_shared_experts,
 correction_bias=self.correction_bias,
 custom_routing_function=self.custom_routing_function,
 routed_scaling_factor=self.routed_scaling_factor,
@@ -886,6 +887,7 @@ class DeepEPMoE(EPMoE):
 renormalize: bool = True,
 use_grouped_topk: bool = False,
 num_expert_group: Optional[int] = None,
+num_fused_shared_experts: int = 0,
 topk_group: Optional[int] = None,
 quant_config: Optional[QuantizationConfig] = None,
 tp_size: Optional[int] = None,
@@ -897,23 +899,24 @@ class DeepEPMoE(EPMoE):
 deepep_mode: DeepEPMode = DeepEPMode.auto,
 ):
 super().__init__(
-num_experts,
+num_experts=num_experts,
-top_k,
+top_k=top_k,
-hidden_size,
+hidden_size=hidden_size,
-intermediate_size,
+intermediate_size=intermediate_size,
-layer_id,
+layer_id=layer_id,
-params_dtype,
+params_dtype=params_dtype,
-renormalize,
+renormalize=renormalize,
-use_grouped_topk,
+use_grouped_topk=use_grouped_topk,
-num_expert_group,
+num_expert_group=num_expert_group,
-topk_group,
+num_fused_shared_experts=num_fused_shared_experts,
-quant_config,
+topk_group=topk_group,
-tp_size,
+quant_config=quant_config,
-prefix,
+tp_size=tp_size,
-correction_bias,
+prefix=prefix,
-custom_routing_function,
+correction_bias=correction_bias,
-activation,
+custom_routing_function=custom_routing_function,
-routed_scaling_factor,
+activation=activation,
+routed_scaling_factor=routed_scaling_factor,
 )
 self.deepep_mode = deepep_mode
 if self.deepep_mode.enable_low_latency():
Reference in New Issue
Block a user