Support more OpenAI API test (#916)

This commit is contained in:
yichuan~
2024-08-05 07:43:09 +08:00
committed by GitHub
parent bb66cc4c52
commit d53dcf9c98
5 changed files with 230 additions and 59 deletions

View File

@@ -92,7 +92,7 @@ class GenerateReqInput:
for element in parallel_sample_num_list
)
if parallel_sample_num > 1 and (not all_equal):
## TODO cope with the case that the parallel_sample_num is different for different samples
# TODO cope with the case that the parallel_sample_num is different for different samples
raise ValueError(
"The parallel_sample_num should be the same for all samples in sample params."
)
@@ -103,14 +103,19 @@ class GenerateReqInput:
if parallel_sample_num != 1:
# parallel sampling +1 represents the original prefill stage
num = parallel_sample_num + 1
if isinstance(self.text, List):
## suppot batch operation
if isinstance(self.text, list):
# suppot batch operation
self.batch_size = len(self.text)
num = num * len(self.text)
elif isinstance(self.input_ids, list) and isinstance(
self.input_ids[0], list
):
self.batch_size = len(self.input_ids)
num = num * len(self.input_ids)
else:
self.batch_size = 1
else:
## support select operation
# support select operation
num = len(self.text) if self.text is not None else len(self.input_ids)
self.batch_size = num