Return logprob for choices (#87)
This commit is contained in:
@@ -11,8 +11,8 @@ class GenerateReqInput:
|
||||
image_data: Optional[Union[List[str], str]] = None
|
||||
sampling_params: Union[List[Dict], Dict] = None
|
||||
rid: Optional[Union[List[str], str]] = None
|
||||
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
|
||||
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
|
||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||
stream: bool = False
|
||||
|
||||
def post_init(self):
|
||||
@@ -23,10 +23,10 @@ class GenerateReqInput:
|
||||
self.sampling_params = {}
|
||||
if self.rid is None:
|
||||
self.rid = uuid.uuid4().hex
|
||||
if self.return_normalized_logprob is None:
|
||||
self.return_normalized_logprob = False
|
||||
if self.normalized_logprob_start_len is None:
|
||||
self.normalized_logprob_start_len = 0
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = False
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = 0
|
||||
else:
|
||||
num = len(self.text)
|
||||
|
||||
@@ -45,17 +45,15 @@ class GenerateReqInput:
|
||||
else:
|
||||
assert isinstance(self.rid, list)
|
||||
|
||||
if self.return_normalized_logprob is None:
|
||||
self.return_normalized_logprob = [False] * num
|
||||
elif not isinstance(self.return_normalized_logprob, list):
|
||||
self.return_normalized_logprob = [self.return_normalized_logprob] * num
|
||||
if self.return_logprob is None:
|
||||
self.return_logprob = [False] * num
|
||||
elif not isinstance(self.return_logprob, list):
|
||||
self.return_logprob = [self.return_logprob] * num
|
||||
|
||||
if self.normalized_logprob_start_len is None:
|
||||
self.normalized_logprob_start_len = [0] * num
|
||||
elif not isinstance(self.normalized_logprob_start_len, list):
|
||||
self.normalized_logprob_start_len = [
|
||||
self.normalized_logprob_start_len
|
||||
] * num
|
||||
if self.logprob_start_len is None:
|
||||
self.logprob_start_len = [0] * num
|
||||
elif not isinstance(self.logprob_start_len, list):
|
||||
self.logprob_start_len = [self.logprob_start_len] * num
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
|
||||
pixel_values: List[float]
|
||||
image_hash: int
|
||||
sampling_params: SamplingParams
|
||||
return_normalized_logprob: bool
|
||||
normalized_logprob_start_len: int
|
||||
return_logprob: bool
|
||||
logprob_start_len: int
|
||||
stream: bool
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user