Fix DeepSeek error when using DeepEP mode (#5190)
This commit is contained in:
@@ -280,10 +280,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
return self.forward_deepep(hidden_states, forward_mode)
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
else:
|
||||
shared_output = None
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
final_hidden_states = (
|
||||
@@ -313,8 +310,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
topk_weights, topk_idx = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
@@ -364,6 +360,12 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def _forward_shared_experts(self, hidden_states):
|
||||
if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
|
||||
return self.shared_experts(hidden_states)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
|
||||
import math
|
||||
|
||||
Reference in New Issue
Block a user