[Eval] Add --repeat in run_eval (#11101)
@@ -10,11 +10,29 @@ import time
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
     Eval,
     make_report,
     set_ulimit,
 )
 
 
+def run_eval_once(args, base_url: str, eval_obj: Eval) -> dict:
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=getattr(args, "max_tokens", 2048),
+        base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
+    )
+
+    # Run eval
+    tic = time.perf_counter()
+    result = eval_obj(sampler)
+    latency = time.perf_counter() - tic
+
+    return result, latency, sampler
+
+
 def run_eval(args):
     set_ulimit()
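The sampler options above are read with getattr(args, name, default), so callers that build their own args object without max_tokens, temperature, or reasoning_effort still get working defaults. A minimal illustration of that fallback, with types.SimpleNamespace standing in for the parsed arguments (the object and its values here are made up):

    from types import SimpleNamespace

    # An args object built before these fields existed: only .model is set.
    old_args = SimpleNamespace(model="dummy-model")

    print(getattr(old_args, "max_tokens", 2048))        # -> 2048 (fallback)
    print(getattr(old_args, "temperature", 0.0))        # -> 0.0 (fallback)
    print(getattr(old_args, "reasoning_effort", None))  # -> None (fallback)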
@@ -68,18 +86,32 @@ def run_eval(args):
     else:
         raise ValueError(f"Invalid eval name: {args.eval_name}")
 
-    sampler = ChatCompletionSampler(
-        model=args.model,
-        max_tokens=getattr(args, "max_tokens", 2048),
-        base_url=base_url,
-        temperature=getattr(args, "temperature", 0.0),
-        reasoning_effort=getattr(args, "reasoning_effort", None),
-    )
+    if getattr(args, "repeat", 1) == 1:
+        result, latency, sampler = run_eval_once(args, base_url, eval_obj)
+    else:
+        from concurrent.futures import ThreadPoolExecutor
 
-    # Run eval
-    tic = time.perf_counter()
-    result = eval_obj(sampler)
-    latency = time.perf_counter() - tic
+        executor = ThreadPoolExecutor(max_workers=args.repeat)
+
+        futures = [
+            executor.submit(run_eval_once, args, base_url, eval_obj)
+            for _ in range(args.repeat)
+        ]
+
+        scores_repeat = []
+
+        for f in futures:
+            result, latency, sampler = f.result()
+            scores_repeat.append(result.score)
+
+        mean_score = sum(scores_repeat) / len(scores_repeat)
+        scores_repeat = [f"{s:.3f}" for s in scores_repeat]
+        print("=" * 20)
+        print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
+        print(f"Scores: {scores_repeat}")
+        print("=" * 20)
+
+        executor.shutdown()
 
     # Dump reports
     metrics = result.metrics | {"score": result.score}
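The new branch is a plain fan-out: when --repeat is greater than 1, that many independent copies of run_eval_once are submitted to a ThreadPoolExecutor and the resulting scores are averaged. A self-contained sketch of the same pattern, with a stub standing in for run_eval_once (the stub and its numbers are illustrative, not from the commit):

    import random
    import time
    from concurrent.futures import ThreadPoolExecutor

    def eval_once(run_id: int) -> float:
        # Stub for run_eval_once: sleep briefly and return a fake score.
        time.sleep(0.1)
        return 0.60 + random.random() * 0.05

    repeat = 4
    with ThreadPoolExecutor(max_workers=repeat) as executor:
        futures = [executor.submit(eval_once, i) for i in range(repeat)]
        scores = [f.result() for f in futures]

    print("=" * 20)
    print(f"Repeat: {repeat}, mean: {sum(scores) / len(scores):.3f}")
    print(f"Scores: {[f'{s:.3f}' for s in scores]}")
    print("=" * 20)

Using the executor as a context manager makes the shutdown implicit; the commit instead calls executor.shutdown() explicitly after draining the futures, which is equivalent here.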
@@ -125,6 +157,9 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat", type=int, default=1, help="repeat the evaluation n times"
|
||||
)
|
||||
parser.add_argument("--eval-name", type=str, default="mmlu")
|
||||
parser.add_argument("--num-examples", type=int)
|
||||
parser.add_argument("--num-threads", type=int, default=512)
|
||||
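With the flag wired into the parser, a repeated run might be launched as, e.g., python -m sglang.test.run_eval --eval-name mmlu --repeat 4 (the module path and the server connection flags are assumed from the surrounding repository and are not shown in this diff); the per-run scores and their mean are then printed by the repeat branch above.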