Logprobs Refactor (#331)
This commit is contained in:
@@ -19,10 +19,13 @@ class GenerateReqInput:
|
||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||
# The start location of the prompt for return_logprob
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||
# The number of top logprobs to return
|
||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||
# Whether to detokenize tokens in logprobs
|
||||
return_text_in_logprobs: bool = False
|
||||
# Whether to stream output
|
||||
stream: bool = False
|
||||
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
||||
|
||||
def post_init(self):
|
||||
is_single = isinstance(self.text, str)
|
||||
@@ -36,6 +39,8 @@ class GenerateReqInput:
|
||||
self.return_logprob = False
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = 0
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
else:
|
||||
num = len(self.text)
|
||||
|
||||
@@ -64,6 +69,11 @@ class GenerateReqInput:
|
||||
elif not isinstance(self.logprob_start_len, list):
|
||||
self.logprob_start_len = [self.logprob_start_len] * num
|
||||
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = [0] * num
|
||||
elif not isinstance(self.top_logprobs_num, list):
|
||||
self.top_logprobs_num = [self.top_logprobs_num] * num
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenizedGenerateReqInput:
|
||||
@@ -76,6 +86,7 @@ class TokenizedGenerateReqInput:
|
||||
sampling_params: SamplingParams
|
||||
return_logprob: bool
|
||||
logprob_start_len: int
|
||||
top_logprobs_num: int
|
||||
stream: bool
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user