diff --git a/python/sglang/eval/loogle_eval.py b/python/sglang/eval/loogle_eval.py index 22fc70541..895362cd1 100644 --- a/python/sglang/eval/loogle_eval.py +++ b/python/sglang/eval/loogle_eval.py @@ -73,6 +73,8 @@ async def benchmark(args): tasks: List[asyncio.Task] = [] for idx, ex in enumerate(dataset): + if idx >= args.num_prompts: + break tasks.append( asyncio.create_task( fetch_response( @@ -103,6 +105,8 @@ def analyse(args): hyps: List[str] = [] refs: List[str] = [] for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")): + if idx >= args.num_prompts: + break pkl_file = output_dir / f"response_{idx}.pkl" if not pkl_file.exists(): raise FileNotFoundError(pkl_file) @@ -150,6 +154,9 @@ if __name__ == "__main__": parser.add_argument( "--output-dir", default="tmp-output-dir", help="Directory for cached responses" ) + parser.add_argument( + "--num-prompts", type=int, default=10000, help="Number of prompts to run" + ) args = parser.parse_args() asyncio.run(benchmark(args))