diff --git a/python/sglang/test/run_eval.py b/python/sglang/test/run_eval.py
index 85f84c36b..8e935e774 100644
--- a/python/sglang/test/run_eval.py
+++ b/python/sglang/test/run_eval.py
@@ -10,11 +10,29 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
+    Eval,
     make_report,
     set_ulimit,
 )
 
 
+def run_eval_once(args, base_url: str, eval_obj: Eval) -> tuple:
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=getattr(args, "max_tokens", 2048),
+        base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
+    )
+
+    # Run eval
+    tic = time.perf_counter()
+    result = eval_obj(sampler)
+    latency = time.perf_counter() - tic
+
+    return result, latency, sampler
+
+
 def run_eval(args):
     set_ulimit()
 
@@ -68,18 +86,32 @@ def run_eval(args):
     else:
         raise ValueError(f"Invalid eval name: {args.eval_name}")
 
-    sampler = ChatCompletionSampler(
-        model=args.model,
-        max_tokens=getattr(args, "max_tokens", 2048),
-        base_url=base_url,
-        temperature=getattr(args, "temperature", 0.0),
-        reasoning_effort=getattr(args, "reasoning_effort", None),
-    )
+    if getattr(args, "repeat", 1) == 1:
+        result, latency, sampler = run_eval_once(args, base_url, eval_obj)
+    else:
+        from concurrent.futures import ThreadPoolExecutor
 
-    # Run eval
-    tic = time.perf_counter()
-    result = eval_obj(sampler)
-    latency = time.perf_counter() - tic
+        executor = ThreadPoolExecutor(max_workers=args.repeat)
+
+        futures = [
+            executor.submit(run_eval_once, args, base_url, eval_obj)
+            for _ in range(args.repeat)
+        ]
+
+        scores_repeat = []
+
+        for f in futures:
+            result, latency, sampler = f.result()
+            scores_repeat.append(result.score)
+
+        mean_score = sum(scores_repeat) / len(scores_repeat)
+        scores_repeat = [f"{s:.3f}" for s in scores_repeat]
+        print("=" * 20)
+        print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
+        print(f"Scores: {scores_repeat}")
+        print("=" * 20)
+
+        executor.shutdown()
 
     # Dump reports
     metrics = result.metrics | {"score": result.score}
@@ -125,6 +157,9 @@ if __name__ == "__main__":
         type=str,
         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
     )
+    parser.add_argument(
+        "--repeat", type=int, default=1, help="Repeat the evaluation n times."
+    )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
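
For reference, a minimal way to exercise the new flag once this change is applied (a sketch; it assumes sglang is installed as a package and an SGLang-compatible server is already running and reachable at the script's default base URL):

    python3 -m sglang.test.run_eval --eval-name mmlu --repeat 3

With --repeat 3 the script submits three concurrent runs of the eval to a thread pool, prints the per-run scores and their mean, and then dumps the report using the result of the final run in the list.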