Logprobs Refractor (#331)

This commit is contained in:
Liangsheng Yin
2024-03-28 14:34:49 +08:00
committed by GitHub
parent 24e59f5350
commit 3842eba5fa
14 changed files with 385 additions and 152 deletions

View File

@@ -19,10 +19,13 @@ class GenerateReqInput:
return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs
return_text_in_logprobs: bool = False
# Whether to stream output
stream: bool = False
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
def post_init(self):
is_single = isinstance(self.text, str)
@@ -36,6 +39,8 @@ class GenerateReqInput:
self.return_logprob = False
if self.logprob_start_len is None:
self.logprob_start_len = 0
if self.top_logprobs_num is None:
self.top_logprobs_num = 0
else:
num = len(self.text)
@@ -64,6 +69,11 @@ class GenerateReqInput:
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num
if self.top_logprobs_num is None:
self.top_logprobs_num = [0] * num
elif not isinstance(self.top_logprobs_num, list):
self.top_logprobs_num = [self.top_logprobs_num] * num
@dataclass
class TokenizedGenerateReqInput:
@@ -76,6 +86,7 @@ class TokenizedGenerateReqInput:
sampling_params: SamplingParams
return_logprob: bool
logprob_start_len: int
top_logprobs_num: int
stream: bool