diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 0aa24f461..c924494b0 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -170,6 +170,7 @@ class DeepEPMoE(FusedMoE): forward_batch: ForwardBatch, forward_shared_experts=None, alt_stream=None, + disable_sbo=False, ): # We have to call SBO inside MoE to be compatible with hooks used in offloading return single_batch_overlap.execute_sbo( @@ -181,6 +182,7 @@ class DeepEPMoE(FusedMoE): experts=self, forward_shared_experts=forward_shared_experts, alt_stream=alt_stream, + disable_sbo=disable_sbo, ) def dispatch( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 327e04c65..e796809a0 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -902,6 +902,8 @@ class DeepseekV2MoE(nn.Module): dict( forward_shared_experts=_forward_shared_experts_and_put_results, alt_stream=self.alt_stream, + # SBO is not yet implemented for NextN + disable_sbo=self.is_nextn, ) if self._fuse_shared_experts_inside_sbo else {} diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py index 885f750ff..77cd41d9d 100644 --- a/python/sglang/srt/single_batch_overlap.py +++ b/python/sglang/srt/single_batch_overlap.py @@ -60,13 +60,14 @@ def execute_sbo( topk_weights: torch.Tensor, forward_batch: ForwardBatch, alt_stream: Optional = None, + disable_sbo: bool = False, ): dispatch_output = experts.dispatch( hidden_states, topk_idx, topk_weights, forward_batch ) combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = ( - _compute_overlap_args(dispatch_output, alt_stream) + _compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo) ) hidden_states = experts.moe_impl( @@ -75,7 +76,7 @@ def execute_sbo( if (e := meta_overlap_args.get("record_event_after_down")) is not None: e.record() - if SboFlags.enable_combine_shared_two_stream_overlap(): + if (not disable_sbo) and SboFlags.enable_combine_shared_two_stream_overlap(): # TODO reduce sm for non-deepgemm with deep_gemm_wrapper.configure_deep_gemm_num_sms( meta_overlap_args["compute_num_sms"] @@ -93,8 +94,8 @@ def execute_sbo( return hidden_states -def _compute_overlap_args(dispatch_output, alt_stream): - if not ( +def _compute_overlap_args(dispatch_output, alt_stream, disable_sbo): + if disable_sbo or not ( SboFlags.enable_combine_down_gemm_two_stream_overlap() or SboFlags.enable_combine_shared_two_stream_overlap() ):