Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)

This commit is contained in:
fzyzcjy
2025-01-19 12:13:27 +08:00
committed by GitHub
parent 4d4cdb3fe7
commit 81d27c8e31
3 changed files with 169 additions and 162 deletions

View File

@@ -97,7 +97,7 @@ from sglang.srt.utils import (
set_random_seed,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)
@@ -422,6 +422,34 @@ class Scheduler:
},
)
self._dispatcher = TypeBasedDispatcher(
[
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReq, self.flush_cache_wrapped),
(AbortReq, self.abort_request),
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
(
UpdateWeightsFromDistributedReqInput,
self.update_weights_from_distributed,
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ProfileReq, self.profile),
(OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session),
(
ReleaseMemoryOccupationReqInput,
lambda _: self.release_memory_occupation(),
),
(
ResumeMemoryOccupationReqInput,
lambda _: self.resume_memory_occupation(),
),
]
)
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self.watchdog_last_forward_ct = 0
@@ -563,57 +591,9 @@ class Scheduler:
def process_input_requests(self, recv_reqs: List):
    """Route each incoming IPC request object to its handler.

    NOTE(review): this span is a rendered diff overlay — the long
    isinstance/elif chain below is the REMOVED pre-refactor body, and the
    final three lines (dispatcher call plus conditional send) are the ADDED
    post-refactor body. Only the dispatcher form is live after this commit.
    """
    for recv_req in recv_reqs:
        # --- REMOVED in this commit: manual type-based dispatch ---
        if isinstance(recv_req, TokenizedGenerateReqInput):
            self.handle_generate_request(recv_req)
        elif isinstance(recv_req, TokenizedEmbeddingReqInput):
            self.handle_embedding_request(recv_req)
        elif isinstance(recv_req, FlushCacheReq):
            self.flush_cache()
        elif isinstance(recv_req, AbortReq):
            self.abort_request(recv_req)
        elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
            success, message = self.update_weights_from_disk(recv_req)
            self.send_to_tokenizer.send_pyobj(
                UpdateWeightFromDiskReqOutput(success, message)
            )
        elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
            success, message = self.init_weights_update_group(recv_req)
            self.send_to_tokenizer.send_pyobj(
                InitWeightsUpdateGroupReqOutput(success, message)
            )
        elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
            success, message = self.update_weights_from_distributed(recv_req)
            self.send_to_tokenizer.send_pyobj(
                UpdateWeightsFromDistributedReqOutput(success, message)
            )
        elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
            success, message = self.update_weights_from_tensor(recv_req)
            self.send_to_tokenizer.send_pyobj(
                UpdateWeightsFromTensorReqOutput(success, message)
            )
        elif isinstance(recv_req, GetWeightsByNameReqInput):
            parameter = self.get_weights_by_name(recv_req)
            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
        elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
            self.release_memory_occupation()
            self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
        elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
            self.resume_memory_occupation()
            self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
        elif isinstance(recv_req, ProfileReq):
            if recv_req == ProfileReq.START_PROFILE:
                self.start_profile()
            else:
                self.stop_profile()
        elif isinstance(recv_req, OpenSessionReqInput):
            session_id, success = self.open_session(recv_req)
            self.send_to_tokenizer.send_pyobj(
                OpenSessionReqOutput(session_id=session_id, success=success)
            )
        elif isinstance(recv_req, CloseSessionReqInput):
            self.close_session(recv_req)
        else:
            raise ValueError(f"Invalid request: {recv_req}")
        # --- ADDED in this commit: table-driven dispatch. Each handler now
        # returns its output object (or None); anything non-None is forwarded
        # to the tokenizer here in one place instead of per-branch above. ---
        output = self._dispatcher(recv_req)
        if output is not None:
            self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
self,
@@ -1545,6 +1525,9 @@ class Scheduler:
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
    # Dispatcher-compatible adapter: accepts the (unused) FlushCacheReq
    # payload required by TypeBasedDispatcher and delegates to flush_cache(),
    # which takes no arguments.
    self.flush_cache()
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
@@ -1597,12 +1580,12 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message
return UpdateWeightFromDiskReqOutput(success, message)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
    """Initialize the online model parameter update group.

    Delegates to the TP worker; after this commit the (success, message)
    pair is wrapped in an InitWeightsUpdateGroupReqOutput so the dispatcher
    can forward it to the tokenizer.
    """
    success, message = self.tp_worker.init_weights_update_group(recv_req)
    # NOTE(review): diff overlay — the bare-tuple return is the REMOVED
    # pre-refactor line; the wrapped return below it is the ADDED (live) line.
    return success, message
    return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
@@ -1615,7 +1598,7 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message
return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
@@ -1626,11 +1609,11 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message
return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Fetch a named model parameter from the TP worker.

    After this commit the value is wrapped in a GetWeightsByNameReqOutput
    for the dispatcher to forward to the tokenizer.
    """
    parameter = self.tp_worker.get_weights_by_name(recv_req)
    # NOTE(review): diff overlay — the bare return is the REMOVED line;
    # the wrapped return below it is the ADDED (live) line.
    return parameter
    return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self):
self.stashed_model_static_state = _export_static_state(
@@ -1638,6 +1621,7 @@ class Scheduler:
)
self.memory_saver_adapter.pause()
self.flush_cache()
return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self):
self.memory_saver_adapter.resume()
@@ -1645,6 +1629,13 @@ class Scheduler:
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
)
del self.stashed_model_static_state
return ResumeMemoryOccupationReqOutput()
def profile(self, recv_req: ProfileReq):
    """Start or stop the profiler based on the request value.

    Any ProfileReq other than START_PROFILE is treated as a stop request.
    Returns None, so the dispatcher forwards nothing to the tokenizer.
    """
    if recv_req == ProfileReq.START_PROFILE:
        self.start_profile()
    else:
        self.stop_profile()
def start_profile(self) -> None:
if self.profiler is None:
@@ -1660,20 +1651,20 @@ class Scheduler:
)
logger.info("Profiler is done")
# NOTE(review): the two signatures below are the pre-/post-refactor variants
# from the rendered diff; the un-annotated second one is the live form after
# this commit (the Tuple return annotation was dropped because the method now
# returns an OpenSessionReqOutput).
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
def open_session(self, recv_req: OpenSessionReqInput):
    """Open a new session for recv_req.session_id.

    Warns and reports failure when the id already exists or is None;
    otherwise registers a Session and reports success. After this commit the
    result is an OpenSessionReqOutput rather than a (session_id, ok) tuple.
    """
    # handle error
    session_id = recv_req.session_id
    if session_id in self.sessions:
        logger.warning(f"session id {session_id} already exist, cannot open.")
        # NOTE(review): in each branch the bare-tuple return is the REMOVED
        # line and the wrapped return below it is the ADDED (live) line.
        return session_id, False
        return OpenSessionReqOutput(session_id, False)
    elif session_id is None:
        logger.warning(f"session id is None, cannot open.")
        return session_id, False
        return OpenSessionReqOutput(session_id, False)
    else:
        self.sessions[session_id] = Session(
            recv_req.capacity_of_str_len, session_id
        )
        return session_id, True
        return OpenSessionReqOutput(session_id, True)
def close_session(self, recv_req: CloseSessionReqInput):
# handle error