Add support for OpenAI API : offline batch(file) processing (#699)
Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
This commit is contained in:
@@ -79,8 +79,26 @@ class GenerateReqInput:
|
||||
if self.top_logprobs_num is None:
|
||||
self.top_logprobs_num = 0
|
||||
else:
|
||||
|
||||
parallel_sample_num = self.sampling_params.get("n", 1)
|
||||
parallel_sample_num_list = []
|
||||
if isinstance(self.sampling_params, dict):
|
||||
parallel_sample_num = self.sampling_params.get("n", 1)
|
||||
elif isinstance(self.sampling_params, list):
|
||||
for sp in self.sampling_params:
|
||||
parallel_sample_num = sp.get("n", 1)
|
||||
parallel_sample_num_list.append(parallel_sample_num)
|
||||
parallel_sample_num = max(parallel_sample_num_list)
|
||||
all_equal = all(
|
||||
element == parallel_sample_num
|
||||
for element in parallel_sample_num_list
|
||||
)
|
||||
if parallel_sample_num > 1 and (not all_equal):
|
||||
## TODO cope with the case that the parallel_sample_num is different for different samples
|
||||
raise ValueError(
|
||||
"The parallel_sample_num should be the same for all samples in sample params."
|
||||
)
|
||||
else:
|
||||
parallel_sample_num = 1
|
||||
self.parallel_sample_num = parallel_sample_num
|
||||
|
||||
if parallel_sample_num != 1:
|
||||
# parallel sampling +1 represents the original prefill stage
|
||||
|
||||
Reference in New Issue
Block a user