diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5138e4a12..5399d6904 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -437,21 +437,21 @@ class DeepseekV2MoE(nn.Module): def forward_normal_dual_stream( self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False ) -> torch.Tensor: - # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) shared_output = self._forward_shared_experts(hidden_states) with torch.cuda.stream(self.alt_stream): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - final_hidden_states = final_hidden_states + shared_output + final_hidden_states += shared_output if self.tp_size > 1 and not can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states