[feat] Add session control (#2073)

This commit is contained in:
Ying Sheng
2024-11-20 00:36:53 -08:00
committed by GitHub
parent 63a395b985
commit 5942dfc00a
8 changed files with 348 additions and 8 deletions

View File

@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
BatchTokenIDOut,
CloseSessionReqInput,
FlushCacheReq,
GetMemPoolSizeReq,
GetMemPoolSizeReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
PrefillAdder,
SchedulePolicy,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache
@@ -106,6 +110,9 @@ class Scheduler:
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
# Session info
self.sessions = {}
# Init inter-process communication
context = zmq.Context(2)
@@ -509,6 +516,11 @@ class Scheduler:
self.start_profile()
else:
self.stop_profile()
elif isinstance(recv_req, OpenSessionReqInput):
session_id = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req)
elif isinstance(recv_req, GetMemPoolSizeReq):
self.send_to_tokenizer.send_pyobj(
GetMemPoolSizeReqOutput(self.max_total_num_tokens)
@@ -520,14 +532,30 @@ class Scheduler:
self,
recv_req: TokenizedGenerateReqInput,
):
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
lora_path=recv_req.lora_path,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
req = Req(
recv_req.rid,
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
lora_path=recv_req.lora_path,
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_id} does not exist"
)
self.waiting_queue.append(req)
return
else:
# Handle sessions
session = self.sessions[recv_req.session_id]
req, new_session_id = session.create_req(recv_req, self.tokenizer)
del self.sessions[recv_req.session_id]
self.sessions[new_session_id] = session
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
return
# Image inputs
if recv_req.image_inputs is not None:
@@ -1151,6 +1179,7 @@ class Scheduler:
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
output_session_ids = []
else: # embedding or reward model
output_embeddings = []
@@ -1178,6 +1207,7 @@ class Scheduler:
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
output_session_ids.append(req.session_id)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
@@ -1228,6 +1258,7 @@ class Scheduler:
output_meta_info,
output_finished_reason,
output_no_stop_trim,
output_session_ids,
)
)
else: # embedding or reward model
@@ -1330,6 +1361,25 @@ class Scheduler:
)
logger.info("Profiler is done")
def open_session(self, recv_req: OpenSessionReqInput) -> str:
# handle error
session_id = recv_req.session_id
if session_id in self.sessions:
logger.warning(f"session id {session_id} already exist, cannot open.")
else:
self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id
)
return session_id
def close_session(self, recv_req: CloseSessionReqInput):
# handle error
session_id = recv_req.session_id
if session_id not in self.sessions:
logger.warning(f"session id {session_id} does not exist, cannot delete.")
else:
del self.sessions[session_id]
def run_scheduler_process(
server_args: ServerArgs,