[script] update loogle test (#7975)
This commit is contained in:
@@ -73,6 +73,8 @@ async def benchmark(args):
|
||||
|
||||
tasks: List[asyncio.Task] = []
|
||||
for idx, ex in enumerate(dataset):
|
||||
if idx >= args.num_prompts:
|
||||
break
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
fetch_response(
|
||||
@@ -103,6 +105,8 @@ def analyse(args):
|
||||
hyps: List[str] = []
|
||||
refs: List[str] = []
|
||||
for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
|
||||
if idx >= args.num_prompts:
|
||||
break
|
||||
pkl_file = output_dir / f"response_{idx}.pkl"
|
||||
if not pkl_file.exists():
|
||||
raise FileNotFoundError(pkl_file)
|
||||
@@ -150,6 +154,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--output-dir", default="tmp-output-dir", help="Directory for cached responses"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=10000, help="Number of prompts to run"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(benchmark(args))
|
||||
|
||||
Reference in New Issue
Block a user