Add activation parameters to fused_moe (#3170)
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ):
         super().__init__()
 
@@ -140,6 +141,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
+        self.activation = activation
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +168,7 @@ class EPMoE(torch.nn.Module):
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         assert self.quant_method is not None
+        assert self.activation == "silu"
 
         if self.grouped_gemm_runner is None:
            self.grouped_gemm_runner = GroupedGemmRunner(
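
The diff threads an `activation` string through `EPMoE`: it is accepted in `__init__` with a default of `"silu"`, stored on the instance, and checked in `forward`, so existing callers keep their behavior and only `"silu"` is accepted for now. In fused MoE layers, `"silu"` conventionally denotes the gated SiLU applied to the concatenated gate/up projection, i.e. `silu(gate) * up`. Below is a minimal sketch of that activation for reference; the `silu_and_mul` helper name and the tensor shapes are illustrative assumptions, not code from this diff.

```python
import torch
import torch.nn.functional as F


def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up: [num_tokens, 2 * intermediate_size], the output of the fused
    # gate/up projection. The first half is the gate, the second half is the
    # up projection; SiLU gates the up half (the "silu" activation referenced
    # by the assert in forward, shown here as an illustrative sketch).
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up


# Example: activation on a [num_tokens, 2 * intermediate_size] buffer.
x = torch.randn(4, 8)
y = silu_and_mul(x)  # -> shape [4, 4]
```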