diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index af82d54a4..4ecc5535b 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import torch
 
+from sglang.srt import single_batch_overlap
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -167,18 +168,20 @@ class DeepEPMoE(FusedMoE):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        forward_shared_experts=None,
+        alt_stream=None,
     ):
-        dispatch_output = self.dispatch(
-            hidden_states, topk_idx, topk_weights, forward_batch
+        # We have to call SBO inside MoE to be compatible with hooks used in offloading
+        return single_batch_overlap.execute_sbo(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+            # SBO args
+            experts=self,
+            forward_shared_experts=forward_shared_experts,
+            alt_stream=alt_stream,
         )
-        hidden_states = self.moe_impl(dispatch_output)
-        hidden_states = self.combine(
-            hidden_states,
-            dispatch_output.topk_idx,
-            dispatch_output.topk_weights,
-            forward_batch,
-        )
-        return hidden_states
 
     def dispatch(
         self,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 3cbaf3ed2..5bf8da10e 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -872,7 +872,7 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if not SboFlags.fuse_shared_experts_inside_sbo():
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
@@ -887,18 +887,27 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )
 
-        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)
+
+        final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
-            # SBO args
-            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
-            experts=self.experts,
-            alt_stream=self.alt_stream,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )
-        if sbo_shared_output is not None:
-            shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py
index c56af1731..dd2be4885 100644
--- a/python/sglang/srt/single_batch_overlap.py
+++ b/python/sglang/srt/single_batch_overlap.py
@@ -42,7 +42,7 @@ class CombineOverlapArgs:
     wait_event: torch.cuda.Event
     num_sms: int
     signal: Optional[torch.Tensor] = None
-    threshold: int = -1
+    threshold: int = 0
 
 
 @dataclass
@@ -61,8 +61,6 @@ def execute_sbo(
     forward_batch: ForwardBatch,
    alt_stream: Optional = None,
 ):
-    shared_output = None
-
     dispatch_output = experts.dispatch(
         hidden_states, topk_idx, topk_weights, forward_batch
     )
@@ -82,7 +80,7 @@
         with deep_gemm_wrapper.configure_deep_gemm_num_sms(
             meta_overlap_args["compute_num_sms"]
         ):
-            shared_output = forward_shared_experts()
+            forward_shared_experts()
 
         hidden_states = experts.combine(
             hidden_states,
@@ -92,7 +90,7 @@
         overlap_args=combine_overlap_args,
     )
 
-    return hidden_states, shared_output
+    return hidden_states
 
 
 def _compute_overlap_args(dispatch_output, alt_stream):