Fuse routed_scaling_factor in DeepSeek (#6710)
This commit is contained in:
@@ -526,9 +526,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
def op_output(self, state):
|
def op_output(self, state):
|
||||||
final_hidden_states = state.pop("hidden_states_after_combine")
|
final_hidden_states = state.pop("hidden_states_after_combine")
|
||||||
|
|
||||||
|
if (shared_output := state.pop("shared_output")) is not None:
|
||||||
|
x = shared_output
|
||||||
|
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
|
||||||
|
final_hidden_states = x
|
||||||
|
else:
|
||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
if (s := state.pop("shared_output")) is not None:
|
|
||||||
final_hidden_states = final_hidden_states + s
|
|
||||||
|
|
||||||
state.hidden_states_mlp_output = final_hidden_states
|
state.hidden_states_mlp_output = final_hidden_states
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user