Files
sglang/python/sglang/srt/managers/io_struct.py
2024-07-07 01:55:58 -07:00

143 lines
4.4 KiB
Python

"""
Definitions of the objects transferred between different
processes (TokenizerManager, DetokenizerManager, Controller).
"""
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
@dataclass
class GenerateReqInput:
    """A generation request describing either one prompt or a batch of prompts.

    Each per-request field accepts either a single value or a list holding
    one value per prompt. Call :meth:`post_init` after construction to
    validate the request and fill in defaults. It is a regular method, not
    the dataclass ``__post_init__`` hook, so it does not run automatically.
    """

    # The input prompt
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The image input
    image_data: Optional[Union[List[str], str]] = None
    # The sampling_params
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id
    rid: Optional[Union[List[str], str]] = None
    # Whether to return logprobs
    return_logprob: Optional[Union[List[bool], bool]] = None
    # The start location of the prompt for return_logprob
    logprob_start_len: Optional[Union[List[int], int]] = None
    # The number of top logprobs to return
    top_logprobs_num: Optional[Union[List[int], int]] = None
    # Whether to detokenize tokens in logprobs
    return_text_in_logprobs: bool = False
    # Whether to stream output
    stream: bool = False

    def post_init(self):
        """Validate the request and normalize its fields in place.

        Sets ``self.is_single``. It is True when ``text`` is a ``str``, or
        when ``text`` is unset and ``input_ids`` is a flat list of ints.

        * Single request: unset ``sampling_params``, ``rid``,
          ``return_logprob``, ``logprob_start_len`` and ``top_logprobs_num``
          get scalar defaults (``{}``, a fresh uuid4 hex string, ``False``,
          ``0``, ``0``). ``image_data`` is left as given.
        * Batch of ``num`` requests: unset fields become lists of ``num``
          defaults, and scalar values are broadcast to ``num`` entries.
          Each request gets its own copy of a broadcast ``sampling_params``
          dict. Lists supplied by the caller are kept as given; their
          length is not checked. ``rid``, if given, must already be a list.

        Raises:
            ValueError: if neither or both of ``text`` and ``input_ids`` are
                set, if ``input_ids`` is empty, or if ``rid`` is not a list
                for a batch request.
        """
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        if self.text is not None:
            is_single = isinstance(self.text, str)
        else:
            # An empty list is ambiguous (empty prompt vs. empty batch) and
            # would otherwise fail with an opaque IndexError below.
            if len(self.input_ids) == 0:
                raise ValueError("input_ids should not be empty.")
            # A flat list of ints is one prompt; a list of lists is a batch.
            is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.sampling_params is None:
                self.sampling_params = {}
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            if self.return_logprob is None:
                self.return_logprob = False
            if self.logprob_start_len is None:
                self.logprob_start_len = 0
            if self.top_logprobs_num is None:
                self.top_logprobs_num = 0
        else:
            num = len(self.text) if self.text is not None else len(self.input_ids)

            if self.image_data is None:
                self.image_data = [None] * num
            elif not isinstance(self.image_data, list):
                self.image_data = [self.image_data] * num

            # Build a distinct dict per request: `[d] * num` would alias a
            # single dict, so mutating one request's params would leak into
            # every other request in the batch.
            if self.sampling_params is None:
                self.sampling_params = [{} for _ in range(num)]
            elif not isinstance(self.sampling_params, list):
                self.sampling_params = [
                    dict(self.sampling_params) for _ in range(num)
                ]

            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(num)]
            elif not isinstance(self.rid, list):
                raise ValueError("The rid should be a list.")

            if self.return_logprob is None:
                self.return_logprob = [False] * num
            elif not isinstance(self.return_logprob, list):
                self.return_logprob = [self.return_logprob] * num

            if self.logprob_start_len is None:
                self.logprob_start_len = [0] * num
            elif not isinstance(self.logprob_start_len, list):
                self.logprob_start_len = [self.logprob_start_len] * num

            if self.top_logprobs_num is None:
                self.top_logprobs_num = [0] * num
            elif not isinstance(self.top_logprobs_num, list):
                self.top_logprobs_num = [self.top_logprobs_num] * num
@dataclass
class TokenizedGenerateReqInput:
    """A single generation request after tokenization.

    NOTE(review): presumably built from one entry of a normalized
    GenerateReqInput before being forwarded to the Controller; confirm
    against the TokenizerManager.
    """

    # The request id
    rid: str
    # The prompt text (assumption: the original, untokenized prompt; confirm)
    input_text: str
    # The token ids of the prompt
    input_ids: List[int]
    # Processed image input. NOTE(review): presumably unset (None) for
    # text-only requests despite the annotation; confirm.
    pixel_values: List[float]
    # Hash of the image input (semantics not visible here; confirm)
    image_hash: int
    # Image dimensions (ordering not visible here; confirm)
    image_size: List[int]
    # Parsed sampling parameters (a SamplingParams object, not the raw dict)
    sampling_params: SamplingParams
    # Whether to return logprobs
    return_logprob: bool
    # The start location of the prompt for return_logprob
    logprob_start_len: int
    # The number of top logprobs to return
    top_logprobs_num: int
    # Whether to stream output
    stream: bool
@dataclass
class BatchTokenIDOut:
    """Output token ids for a batch of requests, plus per-request decoding flags.

    NOTE(review): presumably consumed by the DetokenizerManager. The lists
    appear to be parallel, with one entry per request in ``rids``; confirm
    against the producer and consumer.
    """

    # Ids of the requests in this batch
    rids: List[str]
    # Text already decoded for each request (assumption; confirm)
    decoded_texts: List[str]
    # NOTE(review): surr_/read_ output ids look like the two token windows
    # used for incremental detokenization; confirm the exact semantics.
    surr_output_ids: List[List[int]]
    read_output_ids: List[List[int]]
    # Per-request flags, presumably forwarded to the tokenizer's decode call
    skip_special_tokens: List[bool]
    spaces_between_special_tokens: List[bool]
    # Per-request metadata dicts
    meta_info: List[Dict]
    # Per-request finish reason (whether entries can be None for unfinished
    # requests is not visible here; confirm)
    finished_reason: List[BaseFinishReason]
@dataclass
class BatchStrOut:
    """Decoded string outputs for a batch of requests.

    NOTE(review): the lists appear to be parallel, with one entry per
    request in ``rids``; confirm against the producer.
    """

    # Ids of the requests in this batch
    rids: List[str]
    # Output string for each request
    output_strs: List[str]
    # Per-request metadata dicts
    meta_info: List[Dict]
    # Per-request finish reason (whether entries can be None for unfinished
    # requests is not visible here; confirm)
    finished_reason: List[BaseFinishReason]
@dataclass
class FlushCacheReq:
    """Payload-free control message.

    Judging by its name, it asks the receiving process to flush its cache.
    The message type alone is the signal. NOTE(review): confirm against
    the handler.
    """
@dataclass
class AbortReq:
    """Control message naming a single request by id.

    NOTE(review): presumably asks the receiver to cancel the in-flight
    request whose id is ``rid``; confirm against the handler.
    """

    # Id of the request to abort
    rid: str
@dataclass
class DetokenizeReqInput:
    """Request carrying a list of token ids.

    NOTE(review): by its name, the ids are to be decoded back into text;
    confirm against the caller.
    """

    # Token ids to detokenize
    input_ids: List[int]