From 81d27c8e31c26a435a062fbeaff66357d28a773c Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 19 Jan 2025 12:13:27 +0800 Subject: [PATCH] Refactor to add TypeBasedDispatcher to simplify dispatching (#2958) --- python/sglang/srt/managers/scheduler.py | 113 +++++----- .../sglang/srt/managers/tokenizer_manager.py | 205 +++++++++--------- python/sglang/utils.py | 13 +- 3 files changed, 169 insertions(+), 162 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d62abaff9..d859a30a0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -97,7 +97,7 @@ from sglang.srt.utils import ( set_random_seed, suppress_other_loggers, ) -from sglang.utils import get_exception_traceback +from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -422,6 +422,34 @@ class Scheduler: }, ) + self._dispatcher = TypeBasedDispatcher( + [ + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReq, self.flush_cache_wrapped), + (AbortReq, self.abort_request), + (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), + (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + ( + UpdateWeightsFromDistributedReqInput, + self.update_weights_from_distributed, + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (GetWeightsByNameReqInput, self.get_weights_by_name), + (ProfileReq, self.profile), + (OpenSessionReqInput, self.open_session), + (CloseSessionReqInput, self.close_session), + ( + ReleaseMemoryOccupationReqInput, + lambda _: self.release_memory_occupation(), + ), + ( + ResumeMemoryOccupationReqInput, + lambda _: self.resume_memory_occupation(), + ), + ] + ) + def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 @@ -563,57 +591,9 @@ class Scheduler: def process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: - if isinstance(recv_req, TokenizedGenerateReqInput): - self.handle_generate_request(recv_req) - elif isinstance(recv_req, TokenizedEmbeddingReqInput): - self.handle_embedding_request(recv_req) - elif isinstance(recv_req, FlushCacheReq): - self.flush_cache() - elif isinstance(recv_req, AbortReq): - self.abort_request(recv_req) - elif isinstance(recv_req, UpdateWeightFromDiskReqInput): - success, message = self.update_weights_from_disk(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightFromDiskReqOutput(success, message) - ) - elif isinstance(recv_req, InitWeightsUpdateGroupReqInput): - success, message = self.init_weights_update_group(recv_req) - self.send_to_tokenizer.send_pyobj( - InitWeightsUpdateGroupReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput): - success, message = self.update_weights_from_distributed(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromDistributedReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromTensorReqInput): - success, message = self.update_weights_from_tensor(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromTensorReqOutput(success, message) - ) - elif isinstance(recv_req, GetWeightsByNameReqInput): - parameter = self.get_weights_by_name(recv_req) - self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) - elif 
isinstance(recv_req, ReleaseMemoryOccupationReqInput): - self.release_memory_occupation() - self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput()) - elif isinstance(recv_req, ResumeMemoryOccupationReqInput): - self.resume_memory_occupation() - self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput()) - elif isinstance(recv_req, ProfileReq): - if recv_req == ProfileReq.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(recv_req, OpenSessionReqInput): - session_id, success = self.open_session(recv_req) - self.send_to_tokenizer.send_pyobj( - OpenSessionReqOutput(session_id=session_id, success=success) - ) - elif isinstance(recv_req, CloseSessionReqInput): - self.close_session(recv_req) - else: - raise ValueError(f"Invalid request: {recv_req}") + output = self._dispatcher(recv_req) + if output is not None: + self.send_to_tokenizer.send_pyobj(output) def handle_generate_request( self, @@ -1545,6 +1525,9 @@ class Scheduler: self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def flush_cache_wrapped(self, recv_req: FlushCacheReq): + self.flush_cache() + def flush_cache(self): """Flush the memory pool and cache.""" if len(self.waiting_queue) == 0 and ( @@ -1597,12 +1580,12 @@ class Scheduler: assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightFromDiskReqOutput(success, message) def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): """Initialize the online model parameter update group.""" success, message = self.tp_worker.init_weights_update_group(recv_req) - return success, message + return InitWeightsUpdateGroupReqOutput(success, message) def update_weights_from_distributed( self, @@ -1615,7 +1598,7 @@ class Scheduler: assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromDistributedReqOutput(success, message) def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): """Update the online model parameter from tensors.""" @@ -1626,11 +1609,11 @@ class Scheduler: assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromTensorReqOutput(success, message) def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) - return parameter + return GetWeightsByNameReqOutput(parameter) def release_memory_occupation(self): self.stashed_model_static_state = _export_static_state( @@ -1638,6 +1621,7 @@ class Scheduler: ) self.memory_saver_adapter.pause() self.flush_cache() + return ReleaseMemoryOccupationReqOutput() def resume_memory_occupation(self): self.memory_saver_adapter.resume() @@ -1645,6 +1629,13 @@ class Scheduler: self.tp_worker.worker.model_runner.model, self.stashed_model_static_state ) del self.stashed_model_static_state + return ResumeMemoryOccupationReqOutput() + + def profile(self, recv_req: ProfileReq): + if recv_req == ProfileReq.START_PROFILE: + self.start_profile() + else: + self.stop_profile() def start_profile(self) -> None: if self.profiler is None: @@ -1660,20 +1651,20 @@ class Scheduler: ) logger.info("Profiler is done") - def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]: + def open_session(self, recv_req: OpenSessionReqInput): # 
handle error session_id = recv_req.session_id if session_id in self.sessions: logger.warning(f"session id {session_id} already exist, cannot open.") - return session_id, False + return OpenSessionReqOutput(session_id, False) elif session_id is None: logger.warning(f"session id is None, cannot open.") - return session_id, False + return OpenSessionReqOutput(session_id, False) else: self.sessions[session_id] = Session( recv_req.capacity_of_str_len, session_id ) - return session_id, True + return OpenSessionReqOutput(session_id, True) def close_session(self, recv_req: CloseSessionReqInput): # handle error diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 85dcbcbd0..74f46538c 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -80,7 +80,7 @@ from sglang.srt.utils import ( get_zmq_socket, kill_process_tree, ) -from sglang.utils import get_exception_traceback +from sglang.utils import TypeBasedDispatcher, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -221,6 +221,43 @@ class TokenizerManager: }, ) + self._dispatcher = TypeBasedDispatcher( + [ + (BatchStrOut, self._handle_batch_output), + (BatchEmbeddingOut, self._handle_batch_output), + (BatchTokenIDOut, self._handle_batch_output), + (OpenSessionReqOutput, self._handle_open_session_req_output), + ( + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, + ), + ( + UpdateWeightsFromDistributedReqOutput, + self.update_weights_from_distributed_communicator.handle_recv, + ), + ( + UpdateWeightsFromTensorReqOutput, + self.update_weights_from_tensor_communicator.handle_recv, + ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, + ), + ( + ReleaseMemoryOccupationReqOutput, + self.release_memory_occupation_communicator.handle_recv, + ), + ( + ResumeMemoryOccupationReqOutput, + self.resume_memory_occupation_communicator.handle_recv, + ), + ] + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -712,110 +749,64 @@ class TokenizerManager: """The event loop that handles requests""" while True: - recv_obj: Union[ - BatchStrOut, - BatchEmbeddingOut, - BatchTokenIDOut, - UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqOutput, - GetWeightsByNameReqOutput, - InitWeightsUpdateGroupReqOutput, - ReleaseMemoryOccupationReqOutput, - ResumeMemoryOccupationReqOutput, - ] = await self.recv_from_detokenizer.recv_pyobj() + recv_obj = await self.recv_from_detokenizer.recv_pyobj() + self._dispatcher(recv_obj) - if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue + def _handle_batch_output( + self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] + ): + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue - meta_info = { - "id": rid, - "finish_reason": recv_obj.finished_reasons[i], - "prompt_tokens": recv_obj.prompt_tokens[i], - } + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } - if getattr(state.obj, "return_logprob", False): - self.convert_logprob_style( - meta_info, - state.obj.top_logprobs_num, 
- state.obj.return_text_in_logprobs, - recv_obj, - i, - ) - - if not isinstance(recv_obj, BatchEmbeddingOut): - meta_info.update( - { - "completion_tokens": recv_obj.completion_tokens[i], - "cached_tokens": recv_obj.cached_tokens[i], - } - ) - - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": meta_info, - } - elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { - "token_ids": recv_obj.output_ids[i], - "meta_info": meta_info, - } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": meta_info, - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reasons[i] is not None - state.event.set() - - if self.enable_metrics and state.obj.log_metrics: - self.collect_metrics(state, recv_obj, i) - if ( - self.dump_requests_folder - and state.finished - and state.obj.log_metrics - ): - self.dump_requests(state, out_dict) - elif isinstance(recv_obj, OpenSessionReqOutput): - self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id if recv_obj.success else None + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, ) - elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): - if self.server_args.dp_size == 1: - self.model_update_result.set_result(recv_obj) - else: # self.server_args.dp_size > 1 - self.model_update_tmp.append(recv_obj) - # set future if the all results are recevied - if len(self.model_update_tmp) == self.server_args.dp_size: - self.model_update_result.set_result(self.model_update_tmp) - elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - self.init_weights_update_group_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.update_weights_from_distributed_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.update_weights_from_tensor_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, GetWeightsByNameReqOutput): - self.get_weights_by_name_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput): - self.release_memory_occupation_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput): - self.resume_memory_occupation_communicator.handle_recv(recv_obj) + + if not isinstance(recv_obj, BatchEmbeddingOut): + meta_info.update( + { + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + } + ) + + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchTokenIDOut): + out_dict = { + "token_ids": recv_obj.output_ids[i], + "meta_info": meta_info, + } else: - raise ValueError(f"Invalid object: {recv_obj=}") + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": meta_info, + } + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reasons[i] is not None + state.event.set() + + if 
self.enable_metrics and state.obj.log_metrics:
+            self.collect_metrics(state, recv_obj, i)
+        if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+            self.dump_requests(state, out_dict)
 
     def convert_logprob_style(
         self,
@@ -943,6 +934,20 @@ class TokenizerManager:
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))
 
+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set the future once all results are received
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)
+
 
 async def print_exception_wrapper(func):
     """
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 98e0f3f4f..98942fbb3 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -15,7 +15,7 @@ import urllib.request
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from json import dumps
-from typing import Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import requests
@@ -363,3 +363,14 @@ def terminate_process(process):
 def print_highlight(html_content: str):
     html_content = str(html_content).replace("\n", "
") display(HTML(f"{html_content}")) + + +class TypeBasedDispatcher: + def __init__(self, mapping: List[Tuple[Type, Callable]]): + self._mapping = mapping + + def __call__(self, obj: Any): + for ty, fn in self._mapping: + if isinstance(obj, ty): + return fn(obj) + raise ValueError(f"Invalid object: {obj}")