Return logprob for choices (#87)

This commit is contained in:
Lianmin Zheng
2024-01-23 05:07:30 -08:00
committed by GitHub
parent 9e037c822c
commit 9a16fea012
15 changed files with 161 additions and 112 deletions

View File

@@ -11,8 +11,8 @@ class GenerateReqInput:
image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
return_logprob: Optional[Union[List[bool], bool]] = None
logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False
def post_init(self):
@@ -23,10 +23,10 @@ class GenerateReqInput:
self.sampling_params = {}
if self.rid is None:
self.rid = uuid.uuid4().hex
if self.return_normalized_logprob is None:
self.return_normalized_logprob = False
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = 0
if self.return_logprob is None:
self.return_logprob = False
if self.logprob_start_len is None:
self.logprob_start_len = 0
else:
num = len(self.text)
@@ -45,17 +45,15 @@ class GenerateReqInput:
else:
assert isinstance(self.rid, list)
if self.return_normalized_logprob is None:
self.return_normalized_logprob = [False] * num
elif not isinstance(self.return_normalized_logprob, list):
self.return_normalized_logprob = [self.return_normalized_logprob] * num
if self.return_logprob is None:
self.return_logprob = [False] * num
elif not isinstance(self.return_logprob, list):
self.return_logprob = [self.return_logprob] * num
if self.normalized_logprob_start_len is None:
self.normalized_logprob_start_len = [0] * num
elif not isinstance(self.normalized_logprob_start_len, list):
self.normalized_logprob_start_len = [
self.normalized_logprob_start_len
] * num
if self.logprob_start_len is None:
self.logprob_start_len = [0] * num
elif not isinstance(self.logprob_start_len, list):
self.logprob_start_len = [self.logprob_start_len] * num
@dataclass
@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
pixel_values: List[float]
image_hash: int
sampling_params: SamplingParams
return_normalized_logprob: bool
normalized_logprob_start_len: int
return_logprob: bool
logprob_start_len: int
stream: bool