From 5e786cca3a9a19ac7807144d0035013c85d04d33 Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Thu, 2 Oct 2025 18:04:36 +0800
Subject: [PATCH] Support single batch overlap (#10422)

---
 docs/advanced_features/server_arguments.md    |   1 +
 python/sglang/srt/layers/moe/ep_moe/layer.py  |  19 ++-
 .../srt/layers/moe/flashinfer_cutedsl_moe.py  |  14 ++
 .../srt/layers/moe/token_dispatcher/deepep.py |  61 +++++--
 python/sglang/srt/layers/moe/utils.py         |  10 ++
 .../srt/layers/quantization/modelopt_quant.py |  11 ++
 python/sglang/srt/models/deepseek_v2.py       |  15 +-
 python/sglang/srt/server_args.py              |   6 +
 python/sglang/srt/single_batch_overlap.py     | 151 ++++++++++++++++++
 9 files changed, 268 insertions(+), 20 deletions(-)
 create mode 100644 python/sglang/srt/single_batch_overlap.py

diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 1d8bcaaf5..4e10a6402 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -294,6 +294,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--enable-dp-lm-head` | Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention. | False |
 | `--enable-two-batch-overlap` | Enabling two micro batches to overlap. | False |
 | `--tbo-token-distribution-threshold` | The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap. | 0.48 |
+| `--enable-single-batch-overlap` | Enabling single batch overlap. | False |
 | `--enable-torch-compile` | Optimize the model with torch.compile. Experimental feature. | False |
 | `--torch-compile-max-bs` | Set the maximum batch size when using torch compile. | 32 |
 | `--torchao-config` | Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. | |
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 287bc00fc..5f2813da8 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, List, Optional, Union
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

 import torch
 import triton
@@ -38,10 +39,12 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
+from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 from sglang.srt.utils import (
     ceil_div,
     dispose_tensor,
     get_bool_env_var,
+    get_int_env_var,
     is_cuda,
     is_hip,
     is_npu,
@@ -466,7 +469,11 @@ class DeepEPMoE(EPMoE):
             ),
         )

-    def moe_impl(self, dispatch_output: DispatchOutput):
+    def moe_impl(
+        self,
+        dispatch_output: DispatchOutput,
+        down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None,
+    ):
         from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker

         if _use_aiter:
@@ -481,7 +488,9 @@ class DeepEPMoE(EPMoE):
             return self.forward_deepgemm_contiguous(dispatch_output)
         elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             if get_moe_runner_backend().is_flashinfer_cutedsl():
-                return self.forward_flashinfer_cutedsl(dispatch_output)
+                return self.forward_flashinfer_cutedsl(
+                    dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
+                )
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -495,12 +504,14 @@ class DeepEPMoE(EPMoE):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        overlap_args: Optional[Dict[str, Any]] = None,
     ):
         return self.deepep_dispatcher.combine(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            overlap_args=overlap_args,
         )

     def forward_aiter(
@@ -687,6 +698,7 @@ class DeepEPMoE(EPMoE):
     def forward_flashinfer_cutedsl(
         self,
         dispatch_output: DeepEPLLOutput,
+        down_gemm_overlap_args: Optional[DownGemmOverlapArgs],
     ):
         hidden_states, _, _, masked_m, _ = dispatch_output
         assert self.quant_method is not None
@@ -697,6 +709,7 @@ class DeepEPMoE(EPMoE):
             x=hidden_states,
             masked_m=masked_m,
             moe_runner_config=self.moe_runner_config,
+            down_gemm_overlap_args=down_gemm_overlap_args,
         )
         return output

diff --git a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
index f96361ecb..1d37236e0 100644
--- a/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
+++ b/python/sglang/srt/layers/moe/flashinfer_cutedsl_moe.py
@@ -30,6 +30,9 @@ def flashinfer_cutedsl_moe_masked(
     w2_blockscale: torch.Tensor,
     w2_alpha,
     masked_m: torch.Tensor,
+    down_sm_count: Optional[int] = None,
+    down_signals: Optional[torch.Tensor] = None,
+    down_start_event: Optional[torch.cuda.Event] = None,
 ):
     """
     Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL
@@ -151,6 +154,9 @@ def flashinfer_cutedsl_moe_masked(
         masked_m,
     )

+    if down_start_event is not None:
+        down_start_event.record()
+
     # Gemm2
     out = torch.empty((num_experts, m, k), dtype=torch.bfloat16, device=a_q.device)
     out = out.permute(1, 2, 0)  # requirement of kernel
@@ -165,5 +171,13 @@ def flashinfer_cutedsl_moe_masked(
         sf_vec_size=sf_vec_size,
         alpha=w2_alpha.view(1, 1, num_experts),
         alpha_dtype=get_cute_dtype(w2_alpha),
+        **(
+            dict(
+                sm_count=down_sm_count,
+                dst_signals=down_signals,
+            )
+            if down_sm_count is not None or down_signals is not None
+            else {}
+        ),
     )  # in logical [m, k, l]
     return out.permute(2, 0, 1)
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
index da09f022b..5e980f472 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py
@@ -1,8 +1,9 @@
 from __future__ import annotations

 import logging
+from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union

 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.moe.token_dispatcher.base import (
@@ -25,6 +26,9 @@ from sglang.srt.utils import (

 _is_npu = is_npu()

+if TYPE_CHECKING:
+    from sglang.srt.single_batch_overlap import CombineOverlapArgs
+
 try:
     from deep_ep import Buffer, Config

@@ -310,6 +314,7 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         raise NotImplementedError

@@ -428,6 +433,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             deepep_post_reorder_triton_kernel,
@@ -503,6 +509,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
         self.return_recv_hook = return_recv_hook
+        self.device_module = torch.get_device_module()

     def dispatch_a(
         self,
@@ -570,7 +577,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         use_fp8 = True

         buffer = self._get_buffer()
-        packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
+        packed_recv_hidden, self.packed_recv_count, self.handle, event, hook = (
             buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
@@ -591,23 +598,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
                 and deep_gemm_wrapper.DEEPGEMM_BLACKWELL,
             )
         )
-        return packed_recv_hidden, packed_recv_count, event, hook
+        return packed_recv_hidden, self.packed_recv_count, event, hook

     def combine_a(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         hidden_states, event, hook = self._combine_core(
             hidden_states,
             topk_idx,
             topk_weights,
+            overlap_args=overlap_args,
         )
-        return hidden_states, event, hook
+        return hidden_states, event, hook, overlap_args

-    def combine_b(self, hidden_states, event, hook):
+    def combine_b(self, hidden_states, event, hook, overlap_args):
         hook() if self.return_recv_hook else event.current_stream_wait()
+
+        if overlap_args is not None:
+            self.device_module.current_stream().wait_stream(overlap_args.stream)
+
         return hidden_states

     def _combine_core(
@@ -615,17 +628,35 @@
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        overlap_args: Optional["CombineOverlapArgs"],
     ):
         buffer = self._get_buffer()
-        combined_hidden_states, event, hook = buffer.low_latency_combine(
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            self.handle,
-            async_finish=not self.return_recv_hook,
-            return_recv_hook=self.return_recv_hook,
-        )
-        self.handle = None
+
+        ctx = nullcontext()
+        if overlap_args is not None:
+            overlap_args.stream.wait_event(overlap_args.wait_event)
+            ctx = torch.cuda.stream(overlap_args.stream)
+
+        with ctx:
+            combined_hidden_states, event, hook = buffer.low_latency_combine(
+                x=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                handle=self.handle,
+                async_finish=not self.return_recv_hook,
+                return_recv_hook=self.return_recv_hook,
+                **(
+                    dict(
+                        overlap=overlap_args.overlap,
+                        src_signals=overlap_args.signal,
+                        src_signal_expect_value=overlap_args.threshold,
+                    )
+                    if overlap_args is not None
+                    else {}
+                ),
+            )
+
+        self.packed_recv_count = self.handle = None
         return combined_hidden_states, event, hook

     def _get_buffer(self):
@@ -727,12 +758,14 @@ class DeepEPDispatcher(BaseDispatcher):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         forward_batch: ForwardBatch,
+        overlap_args: Optional["CombineOverlapArgs"] = None,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
+            overlap_args=overlap_args,
         )
         self._combine_intermediate_state = forward_batch, inner_state

diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index fa136d19c..a70d1be40 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -108,6 +108,7 @@ MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
 MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
 DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
+IS_SBO_ENABLED: Optional[bool] = None
 TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
 DEEPEP_CONFIG: Optional[str] = None
 DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
@@ -119,6 +120,7 @@ def initialize_moe_config(server_args: ServerArgs):
     global DEEPEP_MODE
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
+    global IS_SBO_ENABLED
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
     global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER

@@ -127,6 +129,7 @@ def initialize_moe_config(server_args: ServerArgs):
     DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
+    IS_SBO_ENABLED = server_args.enable_single_batch_overlap
     TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
     DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
         server_args.disable_flashinfer_cutlass_moe_fp4_allgather
@@ -172,6 +175,13 @@ def is_tbo_enabled() -> bool:
     return IS_TBO_ENABLED


+def is_sbo_enabled() -> bool:
+    global IS_SBO_ENABLED
+    if IS_SBO_ENABLED is None:
+        IS_SBO_ENABLED = False
+    return IS_SBO_ENABLED
+
+
 def get_tbo_token_distribution_threshold() -> float:
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
     if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index 27a2ea950..7a40d6953 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
         CombineInput,
         StandardDispatchOutput,
     )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs

 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -1468,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         masked_m: torch.Tensor,
         moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
     ) -> torch.Tensor:
         assert (
             moe_runner_config.activation == "silu"
@@ -1495,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             w2_blockscale=layer.w2_blockscale_swizzled,
             w2_alpha=layer.g2_alphas,
             masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
         )
         return out
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index b486740c3..131786946 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -28,6 +28,7 @@ from torch import nn
 from tqdm import tqdm
 from transformers import PretrainedConfig

+from sglang.srt import single_batch_overlap
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -101,6 +102,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -806,7 +808,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
@@ -820,12 +823,18 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )

-        final_hidden_states = self.experts(
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
         )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output

         if shared_output is not None:
             x = shared_output
@@ -843,7 +852,7 @@ class DeepseekV2MoE(nn.Module):
     def _forward_shared_experts(
         self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
     ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
             return self.shared_experts(
                 hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
             )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index dc1fbd2db..84c92983e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -377,6 +377,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_dp_lm_head: bool = False
     enable_two_batch_overlap: bool = False
+    enable_single_batch_overlap: bool = False
     tbo_token_distribution_threshold: float = 0.48
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
@@ -2457,6 +2458,11 @@
             action="store_true",
             help="Enabling two micro batches to overlap.",
         )
+        parser.add_argument(
+            "--enable-single-batch-overlap",
+            action="store_true",
+            help="Let computation and communication overlap within one micro batch.",
+        )
         parser.add_argument(
             "--tbo-token-distribution-threshold",
             type=float,
diff --git a/python/sglang/srt/single_batch_overlap.py b/python/sglang/srt/single_batch_overlap.py
new file mode 100644
index 000000000..b8839c68f
--- /dev/null
+++ b/python/sglang/srt/single_batch_overlap.py
@@ -0,0 +1,151 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+import torch
+
+from sglang.srt.layers.moe import get_moe_runner_backend
+from sglang.srt.layers.moe.utils import is_sbo_enabled
+from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import get_int_env_var
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
+
+
+class SboFlags:
+    # TODO may have: "enable_dispatch_shared_one_stream_overlap", "enable_dispatch_gateup_gemm_two_stream_overlap", ...
+
+    @classmethod
+    def enable_combine_down_gemm_two_stream_overlap(cls):
+        return (
+            is_sbo_enabled()
+            # currently only cutedsl backend supports it
+            and get_moe_runner_backend().is_flashinfer_cutedsl()
+        )
+
+    @classmethod
+    def enable_combine_shared_two_stream_overlap(cls):
+        return is_sbo_enabled()
+
+    @classmethod
+    def fuse_shared_experts_inside_sbo(cls):
+        # TODO after antgroup's PR, should be `... or cls.enable_dispatch_shared_one_stream_overlap()`
+        return cls.enable_combine_shared_two_stream_overlap()
+
+
+@dataclass
+class CombineOverlapArgs:
+    # this "overlap" flag means overlapping with down gemm, not the general two-stream overlap
+    overlap: bool
+    stream: torch.cuda.Stream
+    wait_event: torch.cuda.Event
+    num_sms: int
+    signal: Optional[torch.Tensor] = None
+    threshold: int = -1
+
+
+@dataclass
+class DownGemmOverlapArgs:
+    num_sms: int
+    signal: torch.Tensor
+    start_event: torch.cuda.Event
+
+
+def execute_sbo(
+    forward_shared_experts: Callable[[], Any],
+    experts: "DeepEPMoE",
+    hidden_states: torch.Tensor,
+    topk_idx: torch.Tensor,
+    topk_weights: torch.Tensor,
+    forward_batch: ForwardBatch,
+    alt_stream: Optional = None,
+):
+    shared_output = None
+
+    dispatch_output = experts.dispatch(
+        hidden_states, topk_idx, topk_weights, forward_batch
+    )
+
+    combine_overlap_args, down_gemm_overlap_args, meta_overlap_args = (
+        _compute_overlap_args(dispatch_output, alt_stream)
+    )
+
+    hidden_states = experts.moe_impl(
+        dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
+    )
+    if (e := meta_overlap_args.get("record_event_after_down")) is not None:
+        e.record()
+
+    if SboFlags.enable_combine_shared_two_stream_overlap():
+        # TODO reduce sm for non-deepgemm
+        with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+            meta_overlap_args["compute_num_sms"]
+        ):
+            shared_output = forward_shared_experts()
+
+    hidden_states = experts.combine(
+        hidden_states,
+        dispatch_output.topk_idx,
+        dispatch_output.topk_weights,
+        forward_batch,
+        overlap_args=combine_overlap_args,
+    )
+
+    return hidden_states, shared_output
+
+
+def _compute_overlap_args(dispatch_output, alt_stream):
+    if not (
+        SboFlags.enable_combine_down_gemm_two_stream_overlap()
+        or SboFlags.enable_combine_shared_two_stream_overlap()
+    ):
+        return None, None, {}
+
+    hidden_states = dispatch_output.hidden_states_fp8
+    if isinstance(hidden_states, tuple):
+        hidden_states = hidden_states[0]
+
+    num_local_experts, num_tokens_static, hidden_dim = hidden_states.shape
+
+    total_num_sms = torch.cuda.get_device_properties(
+        device="cuda"
+    ).multi_processor_count
+    communicate_num_sms = get_int_env_var("SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS", 32)
+    compute_num_sms = total_num_sms - communicate_num_sms
+
+    assert alt_stream is not None
+    combine_wait_event = torch.cuda.Event()
+    combine_overlap_args = CombineOverlapArgs(
+        overlap=False,
+        num_sms=communicate_num_sms,
+        stream=alt_stream,
+        wait_event=combine_wait_event,
+    )
+    meta_overlap_args = dict(
+        compute_num_sms=compute_num_sms,
+    )
+    down_gemm_overlap_args = None
+
+    if SboFlags.enable_combine_down_gemm_two_stream_overlap():
+        # TODO use zero_allocator to remove this `torch.zeros` call
+        # NOTE ours v2 use uint32 not int32 currently
+        combine_signal = torch.zeros(
+            num_local_experts, dtype=torch.uint32, device=hidden_states.device
+        )
+
+        down_gemm_overlap_args = DownGemmOverlapArgs(
+            signal=combine_signal,
+            start_event=combine_wait_event,
+            num_sms=compute_num_sms,
+        )
+        combine_overlap_args.overlap = True
+        combine_overlap_args.signal = combine_signal
+        combine_overlap_args.threshold = compute_num_sms
+    else:
+        meta_overlap_args |= dict(
+            record_event_after_down=combine_wait_event,
+        )
+
+    return combine_overlap_args, down_gemm_overlap_args, meta_overlap_args
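
Note: the overlap machinery in this patch follows a standard dual-stream CUDA pattern: record an event on the compute stream once the combine inputs are ready, have an alternate stream wait on that event and issue the DeepEP combine there, keep the shared experts / down GEMM running on the compute stream, and re-join the streams before the combined tensor is consumed (see `combine_b` and `_compute_overlap_args`). Below is a minimal, self-contained sketch of that pattern only; `overlapped_combine` and the `run_*` callables are illustrative placeholder names, not APIs introduced by this patch.

    import torch

    def overlapped_combine(run_down_gemm, run_shared_experts, run_combine, alt_stream):
        # Compute stream: produce the per-expert outputs that the combine consumes.
        expert_out = run_down_gemm()
        ready = torch.cuda.Event()
        ready.record()  # plays the role of CombineOverlapArgs.wait_event / down_start_event

        # Alternate stream: start the combine as soon as its inputs exist.
        alt_stream.wait_event(ready)
        with torch.cuda.stream(alt_stream):
            combined = run_combine(expert_out)

        # Compute stream stays busy (e.g. shared experts) while the combine is in flight.
        shared_out = run_shared_experts()

        # Re-join: do not read `combined` before the alternate stream finishes,
        # mirroring combine_b()'s current_stream().wait_stream(overlap_args.stream).
        torch.cuda.current_stream().wait_stream(alt_stream)
        return combined, shared_out

On top of this pattern, the patch caps DeepGEMM's SM usage via `deep_gemm_wrapper.configure_deep_gemm_num_sms(compute_num_sms)` so the compute-stream kernels leave `SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS` SMs free for the combine-send kernels, and (on the CuteDSL backend) hands a per-expert signal buffer to the down GEMM so the combine can begin before the whole down GEMM has finished.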