[Session] Update session control interface (#2635)

This commit is contained in:
Ying Sheng
2024-12-29 02:10:27 -08:00
committed by GitHub
parent 9c05c6898e
commit e0e09fceeb
6 changed files with 531 additions and 91 deletions

View File

@@ -22,7 +22,7 @@ import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import psutil
import setproctitle
@@ -498,8 +498,10 @@ class Scheduler:
else:
self.stop_profile()
elif isinstance(recv_req, OpenSessionReqInput):
session_id = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
session_id, success = self.open_session(recv_req)
self.send_to_tokenizer.send_pyobj(
OpenSessionReqOutput(session_id=session_id, success=success)
)
elif isinstance(recv_req, CloseSessionReqInput):
self.close_session(recv_req)
else:
@@ -510,7 +512,11 @@ class Scheduler:
recv_req: TokenizedGenerateReqInput,
):
# Create a new request
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
if (
recv_req.session_params is None
or recv_req.session_params.id is None
or recv_req.session_params.id not in self.sessions
):
if recv_req.input_embeds is not None:
# Generate fake input_ids based on the length of input_embeds
@@ -532,15 +538,18 @@ class Scheduler:
)
req.tokenizer = self.tokenizer
if recv_req.session_id is not None:
if (
recv_req.session_params is not None
and recv_req.session_params.id is not None
):
req.finished_reason = FINISH_ABORT(
f"Invalid request: session id {recv_req.session_id} does not exist"
f"Invalid request: session id {recv_req.session_params.id} does not exist"
)
self.waiting_queue.append(req)
return
else:
# Create a new request from a previsou session
session = self.sessions[recv_req.session_id]
# Create a new request from a previous session
session = self.sessions[recv_req.session_params.id]
req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
@@ -1500,16 +1509,20 @@ class Scheduler:
)
logger.info("Profiler is done")
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
    """Register a new session under the id carried by ``recv_req``.

    Args:
        recv_req: Open-session request; its ``session_id`` and
            ``capacity_of_str_len`` fields are read.

    Returns:
        A ``(session_id, success)`` pair. ``success`` is ``False`` and
        nothing is stored when the id is already present in
        ``self.sessions`` or is ``None``; otherwise a new ``Session`` is
        stored under that id and ``success`` is ``True``. The id is echoed
        back unchanged in every case so the caller can forward it.
    """
    session_id = recv_req.session_id
    if session_id in self.sessions:
        # Refuse to overwrite an existing session with the same id.
        logger.warning(f"session id {session_id} already exist, cannot open.")
        return session_id, False
    elif session_id is None:
        logger.warning("session id is None, cannot open.")
        return session_id, False
    else:
        self.sessions[session_id] = Session(
            recv_req.capacity_of_str_len, session_id
        )
        # Always return the 2-tuple: the caller unpacks (session_id, success).
        return session_id, True
def close_session(self, recv_req: CloseSessionReqInput):
# handle error