107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from enum import Enum, auto
|
|
from typing import Any, List, Optional
|
|
|
|
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
|
|
from sglang.srt.poll_based_barrier import PollBasedBarrier
|
|
|
|
# Logger named after this module; used for block/unblock state-transition messages.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class SchedulerInputBlocker:
    """Hold back ordinary scheduler requests between BLOCK and UNBLOCK control requests.

    The blocker is a small state machine over ``_State``:

    * ``UNBLOCKED``: ordinary requests are returned from ``handle`` as they arrive.
    * ``BLOCKED``: entered on a ``BlockReqType.BLOCK`` request. Ordinary requests
      are queued in ``_pending_reqs`` instead of being returned.
    * ``GLOBAL_UNBLOCK_BARRIER``: entered on a ``BlockReqType.UNBLOCK`` request.
      This instance calls ``local_arrive()`` on its unblock barrier and keeps
      queueing until ``poll_global_arrived()`` reports true. At that point the
      queue is released and the state returns to ``UNBLOCKED``.

    Every transition goes through ``_change_state``, which asserts the expected
    current state. An out-of-order BLOCK/UNBLOCK therefore raises AssertionError.
    """

    def __init__(self, noop: bool):
        """Create a blocker in the ``UNBLOCKED`` state.

        Args:
            noop: When true, ``handle`` must be called with ``None`` rather than a
                request list. The same flag is passed to the ``PollBasedBarrier``;
                what the barrier does with it is defined there.
        """
        self._state = _State.UNBLOCKED
        # Ordinary requests held back while not UNBLOCKED, in arrival order.
        self._pending_reqs: List[Any] = []
        self._noop = noop
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]) -> Optional[List[Any]]:
        """Process one batch of received requests and return those to forward.

        Args:
            recv_reqs: Requests received since the previous call. Must be ``None``
                exactly when this blocker was created with ``noop=True``.

        Returns:
            In non-noop mode, the requests that may be forwarded now, in order.
            Requests let straight through come first, followed by any queued
            requests released because the unblock barrier completed.
            In noop mode, ``None``.
        """
        assert (recv_reqs is None) == self._noop

        # Bound unconditionally so the barrier branch below can never touch an
        # unbound name. Previously this list only existed in non-noop mode.
        output_reqs: List[Any] = []
        if not self._noop:
            for recv_req in recv_reqs:
                output_reqs += self._handle_recv_req(recv_req)

        # The barrier is polled on every call, noop or not, and regardless of our
        # own state. NOTE(review): presumably every participant must take part in
        # each poll; confirm against PollBasedBarrier before making it conditional.
        global_arrived_unblock_barrier = (
            self._global_unblock_barrier.poll_global_arrived()
        )
        if (
            self._state == _State.GLOBAL_UNBLOCK_BARRIER
            and global_arrived_unblock_barrier
        ):
            output_reqs += self._handle_arrive_unblock_barrier()

        return None if self._noop else output_reqs

    def _handle_recv_req(self, recv_req) -> List[Any]:
        """Route one received request; return the requests to forward right away.

        ``BlockReqInput`` control requests drive the state machine and are never
        forwarded. Any other request is forwarded when ``UNBLOCKED``; otherwise
        it is queued.

        Raises:
            NotImplementedError: for a ``BlockReqInput`` with an unknown type.
        """
        if isinstance(recv_req, BlockReqInput):
            if recv_req.type == BlockReqType.BLOCK:
                self._execute_block_req()
            elif recv_req.type == BlockReqType.UNBLOCK:
                self._execute_unblock_req()
            else:
                raise NotImplementedError(f"{recv_req=}")
            return []

        if self._state == _State.UNBLOCKED:
            return [recv_req]

        # BLOCKED, or waiting on the global unblock barrier: hold until released.
        self._pending_reqs.append(recv_req)
        return []

    def _execute_block_req(self):
        """Move from ``UNBLOCKED`` to ``BLOCKED``."""
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        """Move from ``BLOCKED`` to ``GLOBAL_UNBLOCK_BARRIER`` and arrive at the barrier.

        Requests stay queued until ``handle`` observes that the barrier has
        globally arrived.
        """
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self) -> List[Any]:
        """Return to ``UNBLOCKED`` and release every queued request in arrival order."""
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        # Hand the queued list to the caller and start a fresh queue.
        output_reqs, self._pending_reqs = self._pending_reqs, []
        return output_reqs

    def _change_state(self, original: "_State", target: "_State"):
        """Transition to ``target``, asserting that the current state is ``original``."""
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
|
|
|
|
|
|
class _State(Enum):
|
|
UNBLOCKED = auto()
|
|
BLOCKED = auto()
|
|
GLOBAL_UNBLOCK_BARRIER = auto()
|
|
|
|
|
|
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Bracket the ``with`` body with BLOCK / UNBLOCK requests to the scheduler.

    A ``BlockReqType.BLOCK`` request is sent before the body runs. A
    ``BlockReqType.UNBLOCK`` request is sent afterwards, even if the body raises.

    Args:
        send_to_scheduler: Any object with a ``send_pyobj`` method.
            NOTE(review): presumably a ZMQ socket connected to the scheduler;
            confirm at call sites.
    """

    def _send(req_type):
        send_to_scheduler.send_pyobj(BlockReqInput(req_type))

    _send(BlockReqType.BLOCK)
    try:
        yield
    finally:
        _send(BlockReqType.UNBLOCK)
|