Refactor to add TypeBasedDispatcher to simplify dispatching (#2958)
This commit is contained in:
@@ -97,7 +97,7 @@ from sglang.srt.utils import (
|
||||
set_random_seed,
|
||||
suppress_other_loggers,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -422,6 +422,34 @@ class Scheduler:
|
||||
},
|
||||
)
|
||||
|
||||
self._dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||
(FlushCacheReq, self.flush_cache_wrapped),
|
||||
(AbortReq, self.abort_request),
|
||||
(UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
|
||||
(InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
|
||||
(
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
self.update_weights_from_distributed,
|
||||
),
|
||||
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
||||
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
||||
(ProfileReq, self.profile),
|
||||
(OpenSessionReqInput, self.open_session),
|
||||
(CloseSessionReqInput, self.close_session),
|
||||
(
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
lambda _: self.release_memory_occupation(),
|
||||
),
|
||||
(
|
||||
ResumeMemoryOccupationReqInput,
|
||||
lambda _: self.resume_memory_occupation(),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def watchdog_thread(self):
|
||||
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
|
||||
self.watchdog_last_forward_ct = 0
|
||||
@@ -563,57 +591,9 @@ class Scheduler:
|
||||
|
||||
def process_input_requests(self, recv_reqs: List):
    """Route every incoming request through the type-based dispatcher.

    Each handler may return a reply object; any non-None reply is
    forwarded to the tokenizer process over ZMQ.
    """
    for req in recv_reqs:
        reply = self._dispatcher(req)
        if reply is not None:
            self.send_to_tokenizer.send_pyobj(reply)
|
||||
|
||||
def handle_generate_request(
|
||||
self,
|
||||
@@ -1545,6 +1525,9 @@ class Scheduler:
|
||||
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
|
||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||
|
||||
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
    """Dispatcher adapter for FlushCacheReq: the payload carries no data, so it is ignored."""
    self.flush_cache()
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
@@ -1597,12 +1580,12 @@ class Scheduler:
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
return UpdateWeightFromDiskReqOutput(success, message)
|
||||
|
||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
    """Initialize the online model parameter update group.

    Delegates to the TP worker and wraps its (success, message) pair in
    the reply object the dispatcher will send back to the tokenizer.
    """
    result = self.tp_worker.init_weights_update_group(recv_req)
    success, message = result
    return InitWeightsUpdateGroupReqOutput(success, message)
|
||||
|
||||
def update_weights_from_distributed(
|
||||
self,
|
||||
@@ -1615,7 +1598,7 @@ class Scheduler:
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
return UpdateWeightsFromDistributedReqOutput(success, message)
|
||||
|
||||
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
||||
"""Update the online model parameter from tensors."""
|
||||
@@ -1626,11 +1609,11 @@ class Scheduler:
|
||||
assert flash_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
return success, message
|
||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Look up the requested weights on the TP worker and wrap them in a reply object."""
    return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
|
||||
|
||||
def release_memory_occupation(self):
|
||||
self.stashed_model_static_state = _export_static_state(
|
||||
@@ -1638,6 +1621,7 @@ class Scheduler:
|
||||
)
|
||||
self.memory_saver_adapter.pause()
|
||||
self.flush_cache()
|
||||
return ReleaseMemoryOccupationReqOutput()
|
||||
|
||||
def resume_memory_occupation(self):
|
||||
self.memory_saver_adapter.resume()
|
||||
@@ -1645,6 +1629,13 @@ class Scheduler:
|
||||
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
|
||||
)
|
||||
del self.stashed_model_static_state
|
||||
return ResumeMemoryOccupationReqOutput()
|
||||
|
||||
def profile(self, recv_req: ProfileReq):
    """Start or stop the profiler depending on the request variant."""
    start_requested = recv_req == ProfileReq.START_PROFILE
    if start_requested:
        self.start_profile()
    else:
        self.stop_profile()
|
||||
|
||||
def start_profile(self) -> None:
|
||||
if self.profiler is None:
|
||||
@@ -1660,20 +1651,20 @@ class Scheduler:
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput):
    """Create a Session for recv_req.session_id.

    Returns an OpenSessionReqOutput whose success flag is False when the
    id is already in use or is None.
    """
    session_id = recv_req.session_id
    # Check for a duplicate id first, then for a missing id (order preserved).
    if session_id in self.sessions:
        logger.warning(f"session id {session_id} already exist, cannot open.")
        return OpenSessionReqOutput(session_id, False)
    if session_id is None:
        logger.warning(f"session id is None, cannot open.")
        return OpenSessionReqOutput(session_id, False)
    self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
    return OpenSessionReqOutput(session_id, True)
|
||||
|
||||
def close_session(self, recv_req: CloseSessionReqInput):
|
||||
# handle error
|
||||
|
||||
@@ -80,7 +80,7 @@ from sglang.srt.utils import (
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
)
|
||||
from sglang.utils import get_exception_traceback
|
||||
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
|
||||
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
@@ -221,6 +221,43 @@ class TokenizerManager:
|
||||
},
|
||||
)
|
||||
|
||||
self._dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
(BatchStrOut, self._handle_batch_output),
|
||||
(BatchEmbeddingOut, self._handle_batch_output),
|
||||
(BatchTokenIDOut, self._handle_batch_output),
|
||||
(OpenSessionReqOutput, self._handle_open_session_req_output),
|
||||
(
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
self._handle_update_weights_from_disk_req_output,
|
||||
),
|
||||
(
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
self.init_weights_update_group_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
self.update_weights_from_distributed_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
self.update_weights_from_tensor_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
GetWeightsByNameReqOutput,
|
||||
self.get_weights_by_name_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
self.release_memory_occupation_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
self.resume_memory_occupation_communicator.handle_recv,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
async def generate_request(
|
||||
self,
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
@@ -712,110 +749,64 @@ class TokenizerManager:
|
||||
"""The event loop that handles requests"""
|
||||
|
||||
while True:
|
||||
recv_obj: Union[
|
||||
BatchStrOut,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
GetWeightsByNameReqOutput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
] = await self.recv_from_detokenizer.recv_pyobj()
|
||||
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
||||
self._dispatcher(recv_obj)
|
||||
|
||||
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
def _handle_batch_output(
|
||||
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
|
||||
):
|
||||
for i, rid in enumerate(recv_obj.rids):
|
||||
state = self.rid_to_state.get(rid, None)
|
||||
if state is None:
|
||||
continue
|
||||
|
||||
meta_info = {
|
||||
"id": rid,
|
||||
"finish_reason": recv_obj.finished_reasons[i],
|
||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||
}
|
||||
meta_info = {
|
||||
"id": rid,
|
||||
"finish_reason": recv_obj.finished_reasons[i],
|
||||
"prompt_tokens": recv_obj.prompt_tokens[i],
|
||||
}
|
||||
|
||||
if getattr(state.obj, "return_logprob", False):
|
||||
self.convert_logprob_style(
|
||||
meta_info,
|
||||
state.obj.top_logprobs_num,
|
||||
state.obj.return_text_in_logprobs,
|
||||
recv_obj,
|
||||
i,
|
||||
)
|
||||
|
||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||
meta_info.update(
|
||||
{
|
||||
"completion_tokens": recv_obj.completion_tokens[i],
|
||||
"cached_tokens": recv_obj.cached_tokens[i],
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reasons[i] is not None
|
||||
state.event.set()
|
||||
|
||||
if self.enable_metrics and state.obj.log_metrics:
|
||||
self.collect_metrics(state, recv_obj, i)
|
||||
if (
|
||||
self.dump_requests_folder
|
||||
and state.finished
|
||||
and state.obj.log_metrics
|
||||
):
|
||||
self.dump_requests(state, out_dict)
|
||||
elif isinstance(recv_obj, OpenSessionReqOutput):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id if recv_obj.success else None
|
||||
if getattr(state.obj, "return_logprob", False):
|
||||
self.convert_logprob_style(
|
||||
meta_info,
|
||||
state.obj.top_logprobs_num,
|
||||
state.obj.return_text_in_logprobs,
|
||||
recv_obj,
|
||||
i,
|
||||
)
|
||||
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp.append(recv_obj)
|
||||
# set the future once all the results are received
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for init parameter update group"
|
||||
self.init_weights_update_group_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for update weights from distributed"
|
||||
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
|
||||
self.get_weights_by_name_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
|
||||
self.release_memory_occupation_communicator.handle_recv(recv_obj)
|
||||
elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
|
||||
self.resume_memory_occupation_communicator.handle_recv(recv_obj)
|
||||
|
||||
if not isinstance(recv_obj, BatchEmbeddingOut):
|
||||
meta_info.update(
|
||||
{
|
||||
"completion_tokens": recv_obj.completion_tokens[i],
|
||||
"cached_tokens": recv_obj.cached_tokens[i],
|
||||
}
|
||||
)
|
||||
|
||||
if isinstance(recv_obj, BatchStrOut):
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Invalid object: {recv_obj=}")
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
out_dict = {
|
||||
"embedding": recv_obj.embeddings[i],
|
||||
"meta_info": meta_info,
|
||||
}
|
||||
state.out_list.append(out_dict)
|
||||
state.finished = recv_obj.finished_reasons[i] is not None
|
||||
state.event.set()
|
||||
|
||||
if self.enable_metrics and state.obj.log_metrics:
|
||||
self.collect_metrics(state, recv_obj, i)
|
||||
if self.dump_requests_folder and state.finished and state.obj.log_metrics:
|
||||
self.dump_requests(state, out_dict)
|
||||
|
||||
def convert_logprob_style(
|
||||
self,
|
||||
@@ -943,6 +934,20 @@ class TokenizerManager:
|
||||
# Schedule the task to run in the background without awaiting it
|
||||
asyncio.create_task(asyncio.to_thread(background_task))
|
||||
|
||||
def _handle_open_session_req_output(self, recv_obj):
|
||||
self.session_futures[recv_obj.session_id].set_result(
|
||||
recv_obj.session_id if recv_obj.success else None
|
||||
)
|
||||
|
||||
def _handle_update_weights_from_disk_req_output(self, recv_obj):
|
||||
if self.server_args.dp_size == 1:
|
||||
self.model_update_result.set_result(recv_obj)
|
||||
else: # self.server_args.dp_size > 1
|
||||
self.model_update_tmp.append(recv_obj)
|
||||
# set future if the all results are recevied
|
||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||
self.model_update_result.set_result(self.model_update_tmp)
|
||||
|
||||
|
||||
async def print_exception_wrapper(func):
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,7 @@ import urllib.request
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from json import dumps
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -363,3 +363,14 @@ def terminate_process(process):
|
||||
def print_highlight(html_content: str):
|
||||
html_content = str(html_content).replace("\n", "<br>")
|
||||
display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
|
||||
|
||||
|
||||
class TypeBasedDispatcher:
    """Dispatch an object to a handler chosen by its runtime type.

    Holds an ordered list of (type, handler) pairs; calling the dispatcher
    invokes the handler of the first pair whose type matches via
    isinstance, so entry order decides ties between related types.
    """

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        # Keep the list as-is: first-match-wins semantics depend on order.
        self._mapping = mapping

    def __call__(self, obj: Any):
        """Return handler(obj) for the first matching type; raise on no match."""
        handler = next(
            (fn for ty, fn in self._mapping if isinstance(obj, ty)), None
        )
        if handler is None:
            raise ValueError(f"Invalid object: {obj}")
        return handler(obj)
|
||||
|
||||
Reference in New Issue
Block a user