Reduce scheduler recv requests overhead (#8947)
This commit is contained in:
@@ -120,6 +120,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
|
|||||||
SchedulerOutputProcessorMixin,
|
SchedulerOutputProcessorMixin,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
||||||
|
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
|
||||||
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||||
SchedulerUpdateWeightsMixin,
|
SchedulerUpdateWeightsMixin,
|
||||||
)
|
)
|
||||||
@@ -474,6 +475,7 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
self.init_profier()
|
self.init_profier()
|
||||||
|
|
||||||
|
self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
|
||||||
self.input_blocker = (
|
self.input_blocker = (
|
||||||
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
|
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
|
||||||
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
||||||
@@ -946,6 +948,14 @@ class Scheduler(
|
|||||||
|
|
||||||
def recv_requests(self) -> List[Req]:
|
def recv_requests(self) -> List[Req]:
|
||||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||||
|
|
||||||
|
if self.recv_skipper is not None:
|
||||||
|
last_forward_mode = (
|
||||||
|
self.last_batch.forward_mode if self.last_batch is not None else None
|
||||||
|
)
|
||||||
|
if not self.recv_skipper.handle(last_forward_mode):
|
||||||
|
return []
|
||||||
|
|
||||||
if self.pp_rank == 0:
|
if self.pp_rank == 0:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
recv_reqs = []
|
recv_reqs = []
|
||||||
|
|||||||
37
python/sglang/srt/managers/scheduler_recv_skipper.py
Normal file
37
python/sglang/srt/managers/scheduler_recv_skipper.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerRecvSkipper:
    """Decide, per scheduler iteration, whether to poll for new requests.

    Every call to `handle` adds a weight (looked up from the previous
    forward mode) to a running counter. Polling is permitted only once the
    counter reaches `server_args.scheduler_recv_interval`; the counter is
    then reset to zero.
    """

    @staticmethod
    def maybe_create(server_args: ServerArgs):
        """Build a skipper, or return None when the interval is <= 1.

        An interval of 1 or less means "poll on every iteration", so no
        skipper object is needed in that case.
        """
        return (
            None
            if server_args.scheduler_recv_interval <= 1
            else SchedulerRecvSkipper(server_args)
        )

    def __init__(self, server_args: ServerArgs):
        """Capture the polling threshold from `server_args`.

        DP attention is rejected up front via an assertion.
        """
        # Can be supported if needed, but may need e.g. `global_forward_mode`
        assert not server_args.enable_dp_attention
        self._threshold = server_args.scheduler_recv_interval
        self._counter = 0

    def handle(self, last_forward_mode: ForwardMode):
        """Account for one iteration and report whether to receive now.

        `last_forward_mode` is the forward mode of the previous batch; the
        weight table also has an entry for None. Modes missing from the
        table fall back to the default weight.

        Returns True when the accumulated weight has reached the threshold
        (the counter is reset in that case), otherwise False.
        """
        self._counter += _WEIGHT_OF_FORWARD_MODE.get(
            last_forward_mode, _DEFAULT_WEIGHT
        )

        if self._counter >= self._threshold:
            # Threshold reached: start a fresh accumulation window.
            self._counter = 0
            return True

        return False
|
||||||
|
|
||||||
|
|
||||||
|
# Per-iteration weights added to the recv-skipper counter, selected by the
# forward mode of the previous batch. All can be tuned if needed.

# Weight used for any forward mode that has no explicit entry below.
_DEFAULT_WEIGHT = 1000

# Forward modes (plus None) that each count as a single step. Switch back to
# a dict literal if these entries ever need different weights.
_WEIGHT_OF_FORWARD_MODE = dict.fromkeys(
    (
        ForwardMode.DECODE,
        ForwardMode.TARGET_VERIFY,
        None,
    ),
    1,
)
|
||||||
@@ -249,6 +249,7 @@ class ServerArgs:
|
|||||||
enable_return_hidden_states: bool = False
|
enable_return_hidden_states: bool = False
|
||||||
enable_triton_kernel_moe: bool = False
|
enable_triton_kernel_moe: bool = False
|
||||||
enable_flashinfer_mxfp4_moe: bool = False
|
enable_flashinfer_mxfp4_moe: bool = False
|
||||||
|
scheduler_recv_interval: int = 1
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
debug_tensor_dump_output_folder: Optional[str] = None
|
debug_tensor_dump_output_folder: Optional[str] = None
|
||||||
@@ -1845,6 +1846,12 @@ class ServerArgs:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
|
help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scheduler-recv-interval",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.scheduler_recv_interval,
|
||||||
|
help="The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this.",
|
||||||
|
)
|
||||||
|
|
||||||
# Debug tensor dumps
|
# Debug tensor dumps
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user