Support more OpenAI API test (#916)
This commit is contained in:
@@ -92,7 +92,7 @@ class GenerateReqInput:
|
||||
for element in parallel_sample_num_list
|
||||
)
|
||||
if parallel_sample_num > 1 and (not all_equal):
|
||||
## TODO cope with the case that the parallel_sample_num is different for different samples
|
||||
# TODO cope with the case that the parallel_sample_num is different for different samples
|
||||
raise ValueError(
|
||||
"The parallel_sample_num should be the same for all samples in sample params."
|
||||
)
|
||||
@@ -103,14 +103,19 @@ class GenerateReqInput:
|
||||
if parallel_sample_num != 1:
|
||||
# parallel sampling +1 represents the original prefill stage
|
||||
num = parallel_sample_num + 1
|
||||
if isinstance(self.text, List):
|
||||
## suppot batch operation
|
||||
if isinstance(self.text, list):
|
||||
# suppot batch operation
|
||||
self.batch_size = len(self.text)
|
||||
num = num * len(self.text)
|
||||
elif isinstance(self.input_ids, list) and isinstance(
|
||||
self.input_ids[0], list
|
||||
):
|
||||
self.batch_size = len(self.input_ids)
|
||||
num = num * len(self.input_ids)
|
||||
else:
|
||||
self.batch_size = 1
|
||||
else:
|
||||
## support select operation
|
||||
# support select operation
|
||||
num = len(self.text) if self.text is not None else len(self.input_ids)
|
||||
self.batch_size = num
|
||||
|
||||
|
||||
Reference in New Issue
Block a user