From 81d27c8e31c26a435a062fbeaff66357d28a773c Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Sun, 19 Jan 2025 12:13:27 +0800
Subject: [PATCH] Refactor to add TypeBasedDispatcher to simplify dispatching
(#2958)
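
This refactor replaces the long isinstance-based if/elif chains in
Scheduler.process_input_requests and TokenizerManager.handle_loop with a
small TypeBasedDispatcher helper added to sglang.utils. Each scheduler
handler now returns its *ReqOutput object (or None), and
process_input_requests forwards any non-None result to the tokenizer, so
handling and reply construction for each request type live in one place.

A minimal usage sketch of the new helper (the request types below are
hypothetical and exist only for illustration; the dispatcher itself is the
one added to python/sglang/utils.py in this patch):

    from dataclasses import dataclass

    from sglang.utils import TypeBasedDispatcher

    @dataclass
    class PingReq:
        payload: str

    @dataclass
    class StopReq:
        pass

    dispatcher = TypeBasedDispatcher(
        [
            # Pairs are tried in order; the first isinstance match wins.
            (PingReq, lambda req: f"pong: {req.payload}"),
            (StopReq, lambda req: None),
        ]
    )

    assert dispatcher(PingReq("hi")) == "pong: hi"
    assert dispatcher(StopReq()) is None
    # Any unregistered type raises ValueError(f"Invalid object: {obj}").
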
---
python/sglang/srt/managers/scheduler.py | 113 +++++-----
.../sglang/srt/managers/tokenizer_manager.py | 205 +++++++++---------
python/sglang/utils.py | 13 +-
3 files changed, 169 insertions(+), 162 deletions(-)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d62abaff9..d859a30a0 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
set_random_seed,
suppress_other_loggers,
)
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)
@@ -422,6 +422,34 @@ class Scheduler:
},
)
+ self._dispatcher = TypeBasedDispatcher(
+ [
+ (TokenizedGenerateReqInput, self.handle_generate_request),
+ (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+ (FlushCacheReq, self.flush_cache_wrapped),
+ (AbortReq, self.abort_request),
+ (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+ (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (
+ UpdateWeightsFromDistributedReqInput,
+ self.update_weights_from_distributed,
+ ),
+ (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+ (GetWeightsByNameReqInput, self.get_weights_by_name),
+ (ProfileReq, self.profile),
+ (OpenSessionReqInput, self.open_session),
+ (CloseSessionReqInput, self.close_session),
+ (
+ ReleaseMemoryOccupationReqInput,
+ lambda _: self.release_memory_occupation(),
+ ),
+ (
+ ResumeMemoryOccupationReqInput,
+ lambda _: self.resume_memory_occupation(),
+ ),
+ ]
+ )
+
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self.watchdog_last_forward_ct = 0
@@ -563,57 +591,9 @@ class Scheduler:
def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
- if isinstance(recv_req, TokenizedGenerateReqInput):
- self.handle_generate_request(recv_req)
- elif isinstance(recv_req, TokenizedEmbeddingReqInput):
- self.handle_embedding_request(recv_req)
- elif isinstance(recv_req, FlushCacheReq):
- self.flush_cache()
- elif isinstance(recv_req, AbortReq):
- self.abort_request(recv_req)
- elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
- success, message = self.update_weights_from_disk(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightFromDiskReqOutput(success, message)
- )
- elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
- success, message = self.init_weights_update_group(recv_req)
- self.send_to_tokenizer.send_pyobj(
- InitWeightsUpdateGroupReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
- success, message = self.update_weights_from_distributed(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromDistributedReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
- success, message = self.update_weights_from_tensor(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromTensorReqOutput(success, message)
- )
- elif isinstance(recv_req, GetWeightsByNameReqInput):
- parameter = self.get_weights_by_name(recv_req)
- self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
- elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
- self.release_memory_occupation()
- self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
- elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
- self.resume_memory_occupation()
- self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
- elif isinstance(recv_req, ProfileReq):
- if recv_req == ProfileReq.START_PROFILE:
- self.start_profile()
- else:
- self.stop_profile()
- elif isinstance(recv_req, OpenSessionReqInput):
- session_id, success = self.open_session(recv_req)
- self.send_to_tokenizer.send_pyobj(
- OpenSessionReqOutput(session_id=session_id, success=success)
- )
- elif isinstance(recv_req, CloseSessionReqInput):
- self.close_session(recv_req)
- else:
- raise ValueError(f"Invalid request: {recv_req}")
+ output = self._dispatcher(recv_req)
+ if output is not None:
+ self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
self,
@@ -1545,6 +1525,9 @@ class Scheduler:
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+ def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+ self.flush_cache()
+
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
@@ -1597,12 +1580,12 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
- return success, message
+ return UpdateWeightFromDiskReqOutput(success, message)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
- return success, message
+ return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
@@ -1615,7 +1598,7 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
- return success, message
+ return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
@@ -1626,11 +1609,11 @@ class Scheduler:
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
- return success, message
+ return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
- return parameter
+ return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self):
self.stashed_model_static_state = _export_static_state(
@@ -1638,6 +1621,7 @@ class Scheduler:
)
self.memory_saver_adapter.pause()
self.flush_cache()
+ return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self):
self.memory_saver_adapter.resume()
@@ -1645,6 +1629,13 @@ class Scheduler:
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
)
del self.stashed_model_static_state
+ return ResumeMemoryOccupationReqOutput()
+
+ def profile(self, recv_req: ProfileReq):
+ if recv_req == ProfileReq.START_PROFILE:
+ self.start_profile()
+ else:
+ self.stop_profile()
def start_profile(self) -> None:
if self.profiler is None:
@@ -1660,20 +1651,20 @@ class Scheduler:
)
logger.info("Profiler is done")
- def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
+ def open_session(self, recv_req: OpenSessionReqInput):
# handle error
session_id = recv_req.session_id
if session_id in self.sessions:
logger.warning(f"session id {session_id} already exist, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
elif session_id is None:
logger.warning(f"session id is None, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
else:
self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id
)
- return session_id, True
+ return OpenSessionReqOutput(session_id, True)
def close_session(self, recv_req: CloseSessionReqInput):
# handle error
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 85dcbcbd0..74f46538c 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,7 +80,7 @@ from sglang.srt.utils import (
get_zmq_socket,
kill_process_tree,
)
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -221,6 +221,43 @@ class TokenizerManager:
},
)
+ self._dispatcher = TypeBasedDispatcher(
+ [
+ (BatchStrOut, self._handle_batch_output),
+ (BatchEmbeddingOut, self._handle_batch_output),
+ (BatchTokenIDOut, self._handle_batch_output),
+ (OpenSessionReqOutput, self._handle_open_session_req_output),
+ (
+ UpdateWeightFromDiskReqOutput,
+ self._handle_update_weights_from_disk_req_output,
+ ),
+ (
+ InitWeightsUpdateGroupReqOutput,
+ self.init_weights_update_group_communicator.handle_recv,
+ ),
+ (
+ UpdateWeightsFromDistributedReqOutput,
+ self.update_weights_from_distributed_communicator.handle_recv,
+ ),
+ (
+ UpdateWeightsFromTensorReqOutput,
+ self.update_weights_from_tensor_communicator.handle_recv,
+ ),
+ (
+ GetWeightsByNameReqOutput,
+ self.get_weights_by_name_communicator.handle_recv,
+ ),
+ (
+ ReleaseMemoryOccupationReqOutput,
+ self.release_memory_occupation_communicator.handle_recv,
+ ),
+ (
+ ResumeMemoryOccupationReqOutput,
+ self.resume_memory_occupation_communicator.handle_recv,
+ ),
+ ]
+ )
+
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -712,110 +749,64 @@ class TokenizerManager:
"""The event loop that handles requests"""
while True:
- recv_obj: Union[
- BatchStrOut,
- BatchEmbeddingOut,
- BatchTokenIDOut,
- UpdateWeightFromDiskReqOutput,
- UpdateWeightsFromDistributedReqOutput,
- GetWeightsByNameReqOutput,
- InitWeightsUpdateGroupReqOutput,
- ReleaseMemoryOccupationReqOutput,
- ResumeMemoryOccupationReqOutput,
- ] = await self.recv_from_detokenizer.recv_pyobj()
+ recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+ self._dispatcher(recv_obj)
- if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
- for i, rid in enumerate(recv_obj.rids):
- state = self.rid_to_state.get(rid, None)
- if state is None:
- continue
+ def _handle_batch_output(
+ self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+ ):
+ for i, rid in enumerate(recv_obj.rids):
+ state = self.rid_to_state.get(rid, None)
+ if state is None:
+ continue
- meta_info = {
- "id": rid,
- "finish_reason": recv_obj.finished_reasons[i],
- "prompt_tokens": recv_obj.prompt_tokens[i],
- }
+ meta_info = {
+ "id": rid,
+ "finish_reason": recv_obj.finished_reasons[i],
+ "prompt_tokens": recv_obj.prompt_tokens[i],
+ }
- if getattr(state.obj, "return_logprob", False):
- self.convert_logprob_style(
- meta_info,
- state.obj.top_logprobs_num,
- state.obj.return_text_in_logprobs,
- recv_obj,
- i,
- )
-
- if not isinstance(recv_obj, BatchEmbeddingOut):
- meta_info.update(
- {
- "completion_tokens": recv_obj.completion_tokens[i],
- "cached_tokens": recv_obj.cached_tokens[i],
- }
- )
-
- if isinstance(recv_obj, BatchStrOut):
- out_dict = {
- "text": recv_obj.output_strs[i],
- "meta_info": meta_info,
- }
- elif isinstance(recv_obj, BatchTokenIDOut):
- out_dict = {
- "token_ids": recv_obj.output_ids[i],
- "meta_info": meta_info,
- }
- else:
- assert isinstance(recv_obj, BatchEmbeddingOut)
- out_dict = {
- "embedding": recv_obj.embeddings[i],
- "meta_info": meta_info,
- }
- state.out_list.append(out_dict)
- state.finished = recv_obj.finished_reasons[i] is not None
- state.event.set()
-
- if self.enable_metrics and state.obj.log_metrics:
- self.collect_metrics(state, recv_obj, i)
- if (
- self.dump_requests_folder
- and state.finished
- and state.obj.log_metrics
- ):
- self.dump_requests(state, out_dict)
- elif isinstance(recv_obj, OpenSessionReqOutput):
- self.session_futures[recv_obj.session_id].set_result(
- recv_obj.session_id if recv_obj.success else None
+ if getattr(state.obj, "return_logprob", False):
+ self.convert_logprob_style(
+ meta_info,
+ state.obj.top_logprobs_num,
+ state.obj.return_text_in_logprobs,
+ recv_obj,
+ i,
)
- elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
- if self.server_args.dp_size == 1:
- self.model_update_result.set_result(recv_obj)
- else: # self.server_args.dp_size > 1
- self.model_update_tmp.append(recv_obj)
- # set future if the all results are recevied
- if len(self.model_update_tmp) == self.server_args.dp_size:
- self.model_update_result.set_result(self.model_update_tmp)
- elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for init parameter update group"
- self.init_weights_update_group_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for update weights from distributed"
- self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for update weights from distributed"
- self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, GetWeightsByNameReqOutput):
- self.get_weights_by_name_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
- self.release_memory_occupation_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
- self.resume_memory_occupation_communicator.handle_recv(recv_obj)
+
+ if not isinstance(recv_obj, BatchEmbeddingOut):
+ meta_info.update(
+ {
+ "completion_tokens": recv_obj.completion_tokens[i],
+ "cached_tokens": recv_obj.cached_tokens[i],
+ }
+ )
+
+ if isinstance(recv_obj, BatchStrOut):
+ out_dict = {
+ "text": recv_obj.output_strs[i],
+ "meta_info": meta_info,
+ }
+ elif isinstance(recv_obj, BatchTokenIDOut):
+ out_dict = {
+ "token_ids": recv_obj.output_ids[i],
+ "meta_info": meta_info,
+ }
else:
- raise ValueError(f"Invalid object: {recv_obj=}")
+ assert isinstance(recv_obj, BatchEmbeddingOut)
+ out_dict = {
+ "embedding": recv_obj.embeddings[i],
+ "meta_info": meta_info,
+ }
+ state.out_list.append(out_dict)
+ state.finished = recv_obj.finished_reasons[i] is not None
+ state.event.set()
+
+ if self.enable_metrics and state.obj.log_metrics:
+ self.collect_metrics(state, recv_obj, i)
+ if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+ self.dump_requests(state, out_dict)
def convert_logprob_style(
self,
@@ -943,6 +934,20 @@ class TokenizerManager:
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
+ def _handle_open_session_req_output(self, recv_obj):
+ self.session_futures[recv_obj.session_id].set_result(
+ recv_obj.session_id if recv_obj.success else None
+ )
+
+ def _handle_update_weights_from_disk_req_output(self, recv_obj):
+ if self.server_args.dp_size == 1:
+ self.model_update_result.set_result(recv_obj)
+ else: # self.server_args.dp_size > 1
+ self.model_update_tmp.append(recv_obj)
+ # set the future once all results are received
+ if len(self.model_update_tmp) == self.server_args.dp_size:
+ self.model_update_result.set_result(self.model_update_tmp)
+
async def print_exception_wrapper(func):
"""
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 98e0f3f4f..98942fbb3 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -15,7 +15,7 @@ import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
-from typing import Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np
import requests
@@ -363,3 +363,14 @@ def terminate_process(process):
def print_highlight(html_content: str):
html_content = str(html_content).replace("\n", "<br>")
display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
+
+
+class TypeBasedDispatcher:
+ def __init__(self, mapping: List[Tuple[Type, Callable]]):
+ self._mapping = mapping
+
+ def __call__(self, obj: Any):
+ for ty, fn in self._mapping:
+ if isinstance(obj, ty):
+ return fn(obj)
+ raise ValueError(f"Invalid object: {obj}")