Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)
This commit is contained in:
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -422,6 +422,34 @@ class Scheduler:
|
||||
},
|
||||
)
|
||||
|
||||
self._dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||
(FlushCacheReq, self.flush_cache_wrapped),
|
||||
(AbortReq, self.abort_request),
|
||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||
(
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
self.update_weights_from_distributed,
|
||||
),
|
||||
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
||||
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
||||
(ProfileReq, self.profile),
|
||||
(OpenSessionReqInput, self.open_session),
|
||||
(CloseSessionReqInput, self.close_session),
|
||||
(
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
lambda _: self.release_memory_occupation(),
|
||||
),
|
||||
(
|
||||
ResumeMemoryOccupationReqInput,
|
||||
lambda _: self.resume_memory_occupation(),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def watchdog_thread(self):
|
||||
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
|
||||
self.watchdog_last_forward_ct = 0
|
||||
@@ -563,57 +591,9 @@ class Scheduler:
|
||||
|
||||
def process_input_requests(self, recv_reqs: List):
|
||||
for recv_req in recv_reqs:
|
||||
if isinstance(recv_req, TokenizedGenerateReqInput):
|
||||
self.handle_generate_request(recv_req)
|
||||
elif isinstance(recv_req, TokenizedEmbeddingReqInput):
|
||||
self.handle_embedding_request(recv_req)
|
||||
elif isinstance(recv_req, FlushCacheReq):
|
||||
self.flush_cache()
|
||||
elif isinstance(recv_req, AbortReq):
|
||||
self.abort_request(recv_req)
|
||||
elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
|
||||
success, message = self.update_weights_from_disk(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightFromDiskReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
|
||||
success, message = self.init_weights_update_group(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
InitWeightsUpdateGroupReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
|
||||
success, message = self.update_weights_from_distributed(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
|
||||
success, message = self.update_weights_from_tensor(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
UpdateWeightsFromTensorReqOutput(success, message)
|
||||
)
|
||||
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
||||
parameter = self.get_weights_by_name(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
||||
elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
|
||||
self.release_memory_occupation()
|
||||
self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
|
||||
elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
|
||||
self.resume_memory_occupation()
|
||||
self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
|
||||
elif isinstance(recv_req, ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
else:
|
||||
self.stop_profile()
|
||||
elif isinstance(recv_req, OpenSessionReqInput):
|
||||
session_id, success = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
OpenSessionReqOutput(session_id=session_id, success=success)
|
||||
)
|
||||
elif isinstance(recv_req, CloseSessionReqInput):
|
||||
self.close_session(recv_req)
|
||||
else:
|
||||
raise ValueError(f"Invalid request: {recv_req}")
|
||||
output = self._dispatcher(recv_req)
|
||||
if output is not None:
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def handle_generate_request(
|
||||
self,
|
||||
@@ -1545,6 +1525,9 @@ class Scheduler:
|
||||
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
|
||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||
|
||||
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
|
||||
self.flush_cache()
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
@@ -1597,12 +1580,12 @@ class Scheduler:
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
return UpdateWeightFromDiskReqOutput(success, message)
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||
"""Initialize the online model parameter update group."""
|
||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||
return success, message
|
||||
return InitWeightsUpdateGroupReqOutput(success, message)
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
@@ -1615,7 +1598,7 @@ class Scheduler:
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
return UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
"""Update the online model parameter from tensors."""
|
||||
@@ -1626,11 +1609,11 @@ class Scheduler:
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return parameter
|
||||
return GetWeightsByNameReqOutput(parameter)
|
||||
|
||||
def release_memory_occupation(self):
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
@@ -1638,6 +1621,7 @@ class Scheduler:
|
||||
)
|
||||
self.memory_saver_adapter.pause()
|
||||
self.flush_cache()
|
||||
return ReleaseMemoryOccupationReqOutput()
|
||||
|
||||
def resume_memory_occupation(self):
|
||||
self.memory_saver_adapter.resume()
|
||||
@@ -1645,6 +1629,13 @@ class Scheduler:
|
||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
return ResumeMemoryOccupationReqOutput()
|
||||
|
||||
def profile(self, recv_req: ProfileReq):
|
||||
if recv_req == ProfileReq.START_PROFILE:
|
||||
self.start_profile()
|
||||
else:
|
||||
self.stop_profile()
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
@@ -1660,20 +1651,20 @@ class Scheduler:
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
|
||||
def open_session(self, recv_req: OpenSessionReqInput):
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id in self.sessions:
|
||||
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||
return session_id, False
|
||||
return OpenSessionReqOutput(session_id, False)
|
||||
elif session_id is None:
|
||||
logger.warning(f"session id is None, cannot open.")
|
||||
return session_id, False
|
||||
return OpenSessionReqOutput(session_id, False)
|
||||
else:
|
||||
self.sessions[session_id] = Session(
|
||||
recv_req.capacity_of_str_len, session_id
|
||||
)
|
||||
return session_id, True
|
||||
return OpenSessionReqOutput(session_id, True)
|
||||
|
||||
def close_session(self, recv_req: CloseSessionReqInput):
|
||||
# handle error
|
||||
|
||||
Reference in New Issue
Block a user