Overlap the gating function with shared experts in DeepSeek (#7978)
This commit is contained in:
@@ -437,21 +437,21 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
     ) -> torch.Tensor:
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
         shared_output = self._forward_shared_experts(hidden_states)
 
         with torch.cuda.stream(self.alt_stream):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
             final_hidden_states = self.experts(
                 hidden_states=hidden_states, router_logits=router_logits
             )
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
-        final_hidden_states = final_hidden_states + shared_output
+        final_hidden_states += shared_output
         if self.tp_size > 1 and not can_fuse_mlp_allreduce:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
Reference in New Issue
Block a user