[feat] Refactor session control interface and add CI (#2173)
This commit is contained in:
@@ -173,7 +173,6 @@ class DetokenizerManager:
|
||||
output_strs=output_strs,
|
||||
meta_info=recv_obj.meta_info,
|
||||
finished_reason=recv_obj.finished_reason,
|
||||
session_ids=recv_obj.session_ids,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
@@ -55,8 +55,9 @@ class GenerateReqInput:
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
|
||||
# Session id info for continual prompting
|
||||
session_id: Optional[Union[List[str], str]] = None
|
||||
session_rid: Optional[Union[List[str], str]] = None
|
||||
session: Optional[
|
||||
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
|
||||
] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
if (self.text is None and self.input_ids is None) or (
|
||||
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
|
||||
lora_path: Optional[str] = None # None means just use the base model
|
||||
|
||||
# Session id info for continual prompting
|
||||
session_id: Optional[int] = None
|
||||
session_id: Optional[str] = None
|
||||
session_rid: Optional[str] = None
|
||||
|
||||
|
||||
@@ -299,8 +300,6 @@ class BatchTokenIDOut:
|
||||
meta_info: List[Dict]
|
||||
finished_reason: List[BaseFinishReason]
|
||||
no_stop_trim: List[bool]
|
||||
# The updated session unique id
|
||||
session_ids: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -313,8 +312,6 @@ class BatchStrOut:
|
||||
meta_info: List[Dict]
|
||||
# The finish reason
|
||||
finished_reason: List[BaseFinishReason]
|
||||
# The updated session unique id
|
||||
session_ids: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -542,9 +542,7 @@ class Scheduler:
|
||||
else:
|
||||
# Handle sessions
|
||||
session = self.sessions[recv_req.session_id]
|
||||
req, new_session_id = session.create_req(recv_req, self.tokenizer)
|
||||
del self.sessions[recv_req.session_id]
|
||||
self.sessions[new_session_id] = session
|
||||
req = session.create_req(recv_req, self.tokenizer)
|
||||
if isinstance(req.finished_reason, FINISH_ABORT):
|
||||
self.waiting_queue.append(req)
|
||||
return
|
||||
@@ -1188,7 +1186,6 @@ class Scheduler:
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_no_stop_trim = []
|
||||
output_session_ids = []
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
|
||||
@@ -1216,7 +1213,6 @@ class Scheduler:
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
)
|
||||
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||
output_session_ids.append(req.session_id)
|
||||
|
||||
meta_info = {
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
@@ -1267,7 +1263,6 @@ class Scheduler:
|
||||
output_meta_info,
|
||||
output_finished_reason,
|
||||
output_no_stop_trim,
|
||||
output_session_ids,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
|
||||
@@ -26,13 +26,13 @@ class Session:
|
||||
self.reqs: List[Req] = []
|
||||
|
||||
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
||||
# renew session id
|
||||
self.session_id = uuid.uuid4().hex
|
||||
if req.session_rid is not None:
|
||||
while len(self.reqs) > 0:
|
||||
if self.reqs[-1].rid == req.session_rid:
|
||||
break
|
||||
self.reqs = self.reqs[:-1]
|
||||
else:
|
||||
self.reqs = []
|
||||
if len(self.reqs) > 0:
|
||||
input_ids = (
|
||||
self.reqs[-1].origin_input_ids
|
||||
@@ -58,4 +58,4 @@ class Session:
|
||||
)
|
||||
else:
|
||||
self.reqs.append(new_req)
|
||||
return new_req, self.session_id
|
||||
return new_req
|
||||
|
||||
@@ -216,8 +216,8 @@ class TokenizerManager:
|
||||
return_logprob = obj.return_logprob
|
||||
logprob_start_len = obj.logprob_start_len
|
||||
top_logprobs_num = obj.top_logprobs_num
|
||||
session_id = obj.session_id
|
||||
session_rid = obj.session_rid
|
||||
session_id = obj.session[0] if obj.session else None
|
||||
session_rid = obj.session[1] if obj.session else None
|
||||
|
||||
if len(input_ids) >= self.context_len:
|
||||
raise ValueError(
|
||||
@@ -570,13 +570,11 @@ class TokenizerManager:
|
||||
out_dict = {
|
||||
"text": recv_obj.output_strs[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
"session_id": recv_obj.session_ids[i],
|
||||
}
|
||||
elif isinstance(recv_obj, BatchTokenIDOut):
|
||||
out_dict = {
|
||||
"token_ids": recv_obj.output_ids[i],
|
||||
"meta_info": recv_obj.meta_info[i],
|
||||
"session_id": recv_obj.session_ids[i],
|
||||
}
|
||||
else:
|
||||
assert isinstance(recv_obj, BatchEmbeddingOut)
|
||||
|
||||
Reference in New Issue
Block a user