Files
sglang/python/sglang/srt/managers/scheduler_input_blocker.py
2025-07-28 22:51:49 -07:00

107 lines
3.6 KiB
Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, List, Optional
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.poll_based_barrier import PollBasedBarrier
logger = logging.getLogger(__name__)
class SchedulerInputBlocker:
    """Gate that can temporarily hold back requests headed for the scheduler.

    The blocker is a small state machine driven by ``BlockReqInput`` control
    messages found in the incoming request stream:

    * ``UNBLOCKED``: ordinary requests are passed straight through.
    * ``BLOCKED``: entered on a ``BLOCK`` message; ordinary requests are
      buffered instead of being passed through.
    * ``GLOBAL_UNBLOCK_BARRIER``: entered on an ``UNBLOCK`` message; requests
      keep being buffered until the ``PollBasedBarrier`` reports global
      arrival, at which point the buffer is flushed and the state returns to
      ``UNBLOCKED``.

    When constructed with ``noop=True`` the instance expects ``handle(None)``
    calls, processes no requests, and only polls the barrier.
    """

    def __init__(self, noop: bool):
        """Start in the unblocked state with an empty buffer.

        Args:
            noop: If true, ``handle`` must be called with ``None`` and returns
                ``None``. The flag is also forwarded to ``PollBasedBarrier``.
        """
        self._noop = noop
        self._state = _State.UNBLOCKED
        # Requests received while not UNBLOCKED, kept in arrival order.
        self._pending_reqs = []
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one batch of received requests through the blocker.

        Args:
            recv_reqs: The batch of received requests, or ``None`` when (and
                only when) this instance was created with ``noop=True``.

        Returns:
            In normal mode, the list of requests that may be forwarded now:
            requests accepted from this batch, followed by any buffered
            requests released because the unblock barrier was reached during
            this call. In noop mode, ``None``.
        """
        assert (recv_reqs is None) == self._noop

        if not self._noop:
            forwarded = [
                passed
                for incoming in recv_reqs
                for passed in self._handle_recv_req(incoming)
            ]

        # The barrier is polled on every call, including in noop mode.
        # NOTE(review): presumably the poll is a collective step every
        # participant has to take part in -- confirm against PollBasedBarrier.
        barrier_reached = self._global_unblock_barrier.poll_global_arrived()
        if self._state == _State.GLOBAL_UNBLOCK_BARRIER and barrier_reached:
            forwarded.extend(self._handle_arrive_unblock_barrier())

        if self._noop:
            return None
        return forwarded

    def _handle_recv_req(self, recv_req):
        """Process a single request and return what should be forwarded now.

        ``BlockReqInput`` messages are consumed here (they only drive the state
        machine) and never forwarded. Any other request is returned as a
        one-element list while unblocked, or buffered otherwise.

        Raises:
            NotImplementedError: For a ``BlockReqInput`` whose ``type`` is
                neither ``BLOCK`` nor ``UNBLOCK``.
        """
        if not isinstance(recv_req, BlockReqInput):
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            self._pending_reqs.append(recv_req)
            return []

        if recv_req.type == BlockReqType.BLOCK:
            self._execute_block_req()
        elif recv_req.type == BlockReqType.UNBLOCK:
            self._execute_unblock_req()
        else:
            raise NotImplementedError(f"{recv_req=}")
        return []

    def _execute_block_req(self):
        """Move from UNBLOCKED to BLOCKED in response to a BLOCK message."""
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        """Move from BLOCKED to the barrier state and register local arrival."""
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED,
            target=_State.GLOBAL_UNBLOCK_BARRIER,
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        """Return to UNBLOCKED and hand back every buffered request.

        Returns:
            A new list holding the previously buffered requests in arrival
            order. The internal buffer is emptied in place.
        """
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER,
            target=_State.UNBLOCKED,
        )
        released = list(self._pending_reqs)
        self._pending_reqs.clear()
        return released

    def _change_state(self, original: "_State", target: "_State"):
        """Transition to ``target``, asserting the current state is ``original``."""
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
class _State(Enum):
UNBLOCKED = auto()
BLOCKED = auto()
GLOBAL_UNBLOCK_BARRIER = auto()
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Bracket a code region with BLOCK / UNBLOCK control messages.

    A ``BlockReqInput`` of type ``BLOCK`` is sent before the body of the
    ``with`` statement runs, and one of type ``UNBLOCK`` is sent when the body
    exits, whether it completes normally or raises.

    Args:
        send_to_scheduler: Object exposing ``send_pyobj(obj)``, used to
            deliver both control messages.
    """
    block_msg = BlockReqInput(BlockReqType.BLOCK)
    send_to_scheduler.send_pyobj(block_msg)
    try:
        yield
    finally:
        # Built only at exit time, mirroring when the message is sent.
        unblock_msg = BlockReqInput(BlockReqType.UNBLOCK)
        send_to_scheduler.send_pyobj(unblock_msg)