[Session] Update session control interface (#2635)
This commit is contained in:
@@ -22,7 +22,7 @@ import warnings
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -498,8 +498,10 @@ class Scheduler:
|
||||
else:
|
||||
self.stop_profile()
|
||||
elif isinstance(recv_req, OpenSessionReqInput):
|
||||
session_id = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
|
||||
session_id, success = self.open_session(recv_req)
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
OpenSessionReqOutput(session_id=session_id, success=success)
|
||||
)
|
||||
elif isinstance(recv_req, CloseSessionReqInput):
|
||||
self.close_session(recv_req)
|
||||
else:
|
||||
@@ -510,7 +512,11 @@ class Scheduler:
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
# Create a new request
|
||||
if recv_req.session_id is None or recv_req.session_id not in self.sessions:
|
||||
if (
|
||||
recv_req.session_params is None
|
||||
or recv_req.session_params.id is None
|
||||
or recv_req.session_params.id not in self.sessions
|
||||
):
|
||||
|
||||
if recv_req.input_embeds is not None:
|
||||
# Generate fake input_ids based on the length of input_embeds
|
||||
@@ -532,15 +538,18 @@ class Scheduler:
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
if recv_req.session_id is not None:
|
||||
if (
|
||||
recv_req.session_params is not None
|
||||
and recv_req.session_params.id is not None
|
||||
):
|
||||
req.finished_reason = FINISH_ABORT(
|
||||
f"Invalid request: session id {recv_req.session_id} does not exist"
|
||||
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
||||
)
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
else:
|
||||
# Create a new request from a previsou session
|
||||
session = self.sessions[recv_req.session_id]
|
||||
# Create a new request from a previous session
|
||||
session = self.sessions[recv_req.session_params.id]
|
||||
req = session.create_req(recv_req, self.tokenizer)
|
||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||
self.waiting_queue.append(req)
|
||||
@@ -1500,16 +1509,20 @@ class Scheduler:
|
||||
)
|
||||
logger.info("Profiler is done")
|
||||
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> str:
|
||||
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
|
||||
# handle error
|
||||
session_id = recv_req.session_id
|
||||
if session_id in self.sessions:
|
||||
logger.warning(f"session id {session_id} already exist, cannot open.")
|
||||
return session_id, False
|
||||
elif session_id is None:
|
||||
logger.warning(f"session id is None, cannot open.")
|
||||
return session_id, False
|
||||
else:
|
||||
self.sessions[session_id] = Session(
|
||||
recv_req.capacity_of_str_len, session_id
|
||||
)
|
||||
return session_id
|
||||
return session_id, True
|
||||
|
||||
def close_session(self, recv_req: CloseSessionReqInput):
|
||||
# handle error
|
||||
|
||||
Reference in New Issue
Block a user