From 98aa836bbf3f8298bee94c652e69e54b2cabb4d9 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 12 Jul 2025 13:41:50 -0700 Subject: [PATCH] Overlap the gating function with shared experts in DeepSeek (#7978) --- python/sglang/srt/models/deepseek_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 5138e4a12..5399d6904 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -437,21 +437,21 @@ class DeepseekV2MoE(nn.Module): def forward_normal_dual_stream( self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False ) -> torch.Tensor: - # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) shared_output = self._forward_shared_experts(hidden_states) with torch.cuda.stream(self.alt_stream): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits ) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - final_hidden_states = final_hidden_states + shared_output + final_hidden_states += shared_output if self.tp_size > 1 and not can_fuse_mlp_allreduce: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states