Fix DeepSeek error when using DeepEP mode (#5190)
This commit is contained in:
@@ -280,10 +280,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
return self.forward_deepep(hidden_states, forward_mode)
|
return self.forward_deepep(hidden_states, forward_mode)
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
shared_output = self.shared_experts(hidden_states)
|
|
||||||
else:
|
|
||||||
shared_output = None
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
final_hidden_states = (
|
final_hidden_states = (
|
||||||
@@ -313,8 +310,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
):
|
):
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
if self.n_shared_experts is not None:
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
shared_output = self.shared_experts(hidden_states)
|
|
||||||
topk_weights, topk_idx = select_experts(
|
topk_weights, topk_idx = select_experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
@@ -364,6 +360,12 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
|
def _forward_shared_experts(self, hidden_states):
    """Run the shared experts on ``hidden_states`` when they are active.

    The shared experts are applied only when ``self.n_shared_experts`` is
    set (not ``None``) and ``self.n_share_experts_fusion`` equals ``0``.

    Args:
        hidden_states: Input passed unchanged to ``self.shared_experts``.
            NOTE(review): presumably a (num_tokens, hidden_dim) tensor --
            confirm against the callers.

    Returns:
        The result of ``self.shared_experts(hidden_states)`` when both
        conditions hold, otherwise ``None``.
    """
    # NOTE(review): a non-zero fusion count presumably means the shared
    # experts are handled elsewhere (fused into the routed experts) --
    # confirm before relying on this.
    if self.n_shared_experts is None or self.n_share_experts_fusion != 0:
        return None
    return self.shared_experts(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||||
import math
|
import math
|
||||||
|
|||||||
Reference in New Issue
Block a user