[feat] Refactor session control interface and add CI (#2173)

This commit is contained in:
Ying Sheng
2024-11-25 12:32:51 -08:00
committed by GitHub
parent 5ada33ffa0
commit e1e595d702
8 changed files with 180 additions and 154 deletions

View File

@@ -173,7 +173,6 @@ class DetokenizerManager:
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
session_ids=recv_obj.session_ids,
)
)

View File

@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -55,8 +55,9 @@ class GenerateReqInput:
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
# Session id info for continual prompting
session_id: Optional[Union[List[str], str]] = None
session_rid: Optional[Union[List[str], str]] = None
session: Optional[
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
] = None
def normalize_batch_and_arguments(self):
if (self.text is None and self.input_ids is None) or (
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
lora_path: Optional[str] = None # None means just use the base model
# Session id info for continual prompting
session_id: Optional[int] = None
session_id: Optional[str] = None
session_rid: Optional[str] = None
@@ -299,8 +300,6 @@ class BatchTokenIDOut:
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_stop_trim: List[bool]
# The updated session unique id
session_ids: List[str]
@dataclass
@@ -313,8 +312,6 @@ class BatchStrOut:
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# The update session unique id
session_ids: List[str]
@dataclass

View File

@@ -542,9 +542,7 @@ class Scheduler:
else:
# Handle sessions
session = self.sessions[recv_req.session_id]
req, new_session_id = session.create_req(recv_req, self.tokenizer)
del self.sessions[recv_req.session_id]
self.sessions[new_session_id] = session
req = session.create_req(recv_req, self.tokenizer)
if isinstance(req.finished_reason, FINISH_ABORT):
self.waiting_queue.append(req)
return
@@ -1188,7 +1186,6 @@ class Scheduler:
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
output_session_ids = []
else: # embedding or reward model
output_embeddings = []
@@ -1216,7 +1213,6 @@ class Scheduler:
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
output_session_ids.append(req.session_id)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
@@ -1267,7 +1263,6 @@ class Scheduler:
output_meta_info,
output_finished_reason,
output_no_stop_trim,
output_session_ids,
)
)
else: # embedding or reward model

View File

@@ -26,13 +26,13 @@ class Session:
self.reqs: List[Req] = []
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
# renew session id
self.session_id = uuid.uuid4().hex
if req.session_rid is not None:
while len(self.reqs) > 0:
if self.reqs[-1].rid == req.session_rid:
break
self.reqs = self.reqs[:-1]
else:
self.reqs = []
if len(self.reqs) > 0:
input_ids = (
self.reqs[-1].origin_input_ids
@@ -58,4 +58,4 @@ class Session:
)
else:
self.reqs.append(new_req)
return new_req, self.session_id
return new_req

View File

@@ -216,8 +216,8 @@ class TokenizerManager:
return_logprob = obj.return_logprob
logprob_start_len = obj.logprob_start_len
top_logprobs_num = obj.top_logprobs_num
session_id = obj.session_id
session_rid = obj.session_rid
session_id = obj.session[0] if obj.session else None
session_rid = obj.session[1] if obj.session else None
if len(input_ids) >= self.context_len:
raise ValueError(
@@ -570,13 +570,11 @@ class TokenizerManager:
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"session_id": recv_obj.session_ids[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)