[feat] Add session control (#2073)
This commit is contained in:
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
BatchTokenIDOut,
|
||||
CloseSessionReqInput,
|
||||
FlushCacheReq,
|
||||
GetMemPoolSizeReq,
|
||||
GetMemPoolSizeReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
|
||||
PrefillAdder,
|
||||
SchedulePolicy,
|
||||
)
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
@@ -106,6 +110,9 @@ class Scheduler:
|
||||
self.skip_tokenizer_init = server_args.skip_tokenizer_init
|
||||
self.enable_metrics = server_args.enable_metrics
|
||||
|
||||
# Session info
|
||||
self.sessions = {}
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
|
||||
@@ -509,6 +516,11 @@ class Scheduler:
|
||||
self.start_profile()
|
||||
else:
|
||||
self.stop_profile()
|
||||
elif isinstance(recv_req, OpenSessionReqInput):
|
||||
session_id = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
||||
elif isinstance(recv_req, CloseSessionReqInput):
|
||||
self.close_session(recv_req)
|
||||
elif isinstance(recv_req, GetMemPoolSizeReq):
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
|
||||
@@ -520,14 +532,30 @@ class Scheduler:
|
||||
self,
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
lora_path=recv_req.lora_path,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||
req = Req(
|
||||
recv_req.rid,
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
lora_path=recv_req.lora_path,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
if recv_req.session_id is not None:
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
f"Invalid request: session id {recv_req.session_id} does not exist"
|
||||
)
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
else:
|
||||
# Handle sessions
|
||||
session = self.sessions[recv_req.session_id]
|
||||
req, new_session_id = session.create_req(recv_req, self.tokenizer)
|
||||
del self.sessions[recv_req.session_id]
|
||||
self.sessions[new_session_id] = session
|
||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
|
||||
# Image inputs
|
||||
if recv_req.image_inputs is not None:
|
||||
@@ -1151,6 +1179,7 @@ class Scheduler:
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_no_stop_trim = []
|
||||
output_session_ids = []
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
|
||||
@@ -1178,6 +1207,7 @@ class Scheduler:
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
)
|
||||
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||
output_session_ids.append(req.session_id)
|
||||
|
||||
meta_info = {
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
@@ -1228,6 +1258,7 @@ class Scheduler:
|
||||
output_meta_info,
|
||||
output_finished_reason,
|
||||
output_no_stop_trim,
|
||||
output_session_ids,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
@@ -1330,6 +1361,25 @@ class Scheduler:
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> str:
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id in self.sessions:
|
||||
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||
else:
|
||||
self.sessions[session_id] = Session(
|
||||
recv_req.capacity_of_str_len, session_id
|
||||
)
|
||||
return session_id
|
||||
|
||||
def close_session(self, recv_req: CloseSessionReqInput):
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id not in self.sessions:
|
||||
logger.warning(f"session id {session_id} does not exist, cannot delete.")
|
||||
else:
|
||||
del self.sessions[session_id]
|
||||
|
||||
|
||||
def run_scheduler_process(
|
||||
server_args: ServerArgs,
|
||||
|
||||
Reference in New Issue
Block a user