Format Benchmark Code (#399)
This commit is contained in:
@@ -3,12 +3,16 @@ import json
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import read_jsonl
|
||||
|
||||
|
||||
def get_one_example(lines, i, include_answer):
|
||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||
if include_answer:
|
||||
ret += lines[i]["endings"][lines[i]["label"]]
|
||||
return ret
|
||||
@@ -31,21 +35,18 @@ def main(args):
|
||||
questions = []
|
||||
choices = []
|
||||
labels = []
|
||||
for i in range(len(lines[:args.num_questions])):
|
||||
for i in range(len(lines[: args.num_questions])):
|
||||
questions.append(get_one_example(lines, i, False))
|
||||
choices.append(lines[i]["endings"])
|
||||
labels.append(lines[i]["label"])
|
||||
arguments = [
|
||||
{"question": q, "choices": c}
|
||||
for q, c in zip(questions, choices)
|
||||
]
|
||||
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
|
||||
|
||||
#####################################
|
||||
######### SGL Program Begin #########
|
||||
#####################################
|
||||
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
@sgl.function
|
||||
def few_shot_hellaswag(s, question, choices):
|
||||
s += few_shot_examples + question
|
||||
@@ -61,7 +62,12 @@ def main(args):
|
||||
# Run requests
|
||||
tic = time.time()
|
||||
rets = few_shot_hellaswag.run_batch(
|
||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
||||
arguments,
|
||||
temperature=0,
|
||||
backend=backend,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
||||
latency = time.time() - tic
|
||||
|
||||
@@ -82,7 +88,7 @@ def main(args):
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user