diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 6fd6ffe64..037731c70 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -120,6 +120,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
 from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
 from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
@@ -474,6 +475,7 @@ class Scheduler(
         )
         self.init_profier()
 
+        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
             SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
             if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -946,6 +948,14 @@ class Scheduler(
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+        if self.recv_skipper is not None:
+            last_forward_mode = (
+                self.last_batch.forward_mode if self.last_batch is not None else None
+            )
+            if not self.recv_skipper.handle(last_forward_mode):
+                return []
+
         if self.pp_rank == 0:
             if self.attn_tp_rank == 0:
                 recv_reqs = []
diff --git a/python/sglang/srt/managers/scheduler_recv_skipper.py b/python/sglang/srt/managers/scheduler_recv_skipper.py
new file mode 100644
--- /dev/null
+++ b/python/sglang/srt/managers/scheduler_recv_skipper.py
@@ -0,0 +1,68 @@
+"""Throttle how often the scheduler polls for new requests."""
+
+from typing import Optional
+
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.server_args import ServerArgs
+
+# Weight added to the skip counter after each scheduler iteration, keyed by the
+# forward mode of the previous batch (``None`` means there was no previous
+# batch). All can be tuned if needed.
+_DEFAULT_WEIGHT = 1000
+_WEIGHT_OF_FORWARD_MODE = {
+    ForwardMode.DECODE: 1,
+    ForwardMode.TARGET_VERIFY: 1,
+    None: 1,
+}
+
+
+class SchedulerRecvSkipper:
+    """Decide, once per scheduler iteration, whether to poll for requests.
+
+    Every call to :meth:`handle` adds a weight to an internal counter; a poll
+    is allowed only once the counter reaches ``scheduler_recv_interval``,
+    after which the counter restarts from zero. Modes listed in
+    ``_WEIGHT_OF_FORWARD_MODE`` add 1; any other mode adds
+    ``_DEFAULT_WEIGHT``.
+    """
+
+    @staticmethod
+    def maybe_create(server_args: ServerArgs) -> Optional["SchedulerRecvSkipper"]:
+        """Return a skipper, or ``None`` when the interval is at most 1."""
+        if server_args.scheduler_recv_interval <= 1:
+            return None
+        return SchedulerRecvSkipper(server_args)
+
+    def __init__(self, server_args: ServerArgs):
+        """Read the polling interval from ``server_args``.
+
+        Raises:
+            ValueError: If DP attention is enabled.
+        """
+        # Can be supported if needed, but may need e.g. `global_forward_mode`.
+        # Raised explicitly rather than asserted so the check survives `-O`.
+        if server_args.enable_dp_attention:
+            raise ValueError(
+                "SchedulerRecvSkipper does not support enable_dp_attention"
+            )
+        self._counter = 0
+        self._threshold = server_args.scheduler_recv_interval
+
+    def handle(self, last_forward_mode: Optional[ForwardMode]) -> bool:
+        """Account for the previous iteration and report whether to poll now.
+
+        Args:
+            last_forward_mode: Forward mode of the previous batch, or ``None``
+                when there was no previous batch.
+
+        Returns:
+            ``True`` if requests should be received in this iteration.
+        """
+        self._counter += _WEIGHT_OF_FORWARD_MODE.get(
+            last_forward_mode, _DEFAULT_WEIGHT
+        )
+        if self._counter < self._threshold:
+            return False
+
+        self._counter = 0
+        return True
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 217abc337..30a210980 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -249,6 +249,7 @@ class ServerArgs:
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
     enable_flashinfer_mxfp4_moe: bool = False
+    scheduler_recv_interval: int = 1
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -1845,6 +1846,12 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
         )
+        parser.add_argument(
+            "--scheduler-recv-interval",
+            type=int,
+            default=ServerArgs.scheduler_recv_interval,
+            help="The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this.",
+        )
 
         # Debug tensor dumps
         parser.add_argument(