Add city doc benchmark mode (#129)

This commit is contained in:
Liangsheng Yin
2024-02-01 13:38:47 +08:00
committed by GitHub
parent c7af9f7393
commit 79cb018e4b
4 changed files with 268 additions and 22 deletions

View File

@@ -7,7 +7,7 @@ from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
from sglang.utils import dump_state_text, read_jsonl
# there are some FSM bugs with json regex converted from pydantic model
# here use a string regex instead
@@ -29,6 +29,16 @@ character_regex = (
+ r"""\}"""
)
city_regex = (
r"""\{\n"""
+ r""" "name": "[\w\d\s]{1,16}",\n"""
+ r""" "country": "[\w\d\s]{1,16}",\n"""
+ r""" "latitude": [-+]?[0-9]*\.?[0-9]{0,2},\n"""
+ r""" "population": [-+]?[0-9]{1,9},\n"""
+ r""" "top 3 landmarks": \["[\w\d\s]{1,16}", "[\w\d\s]{1,16}", "[\w\d\s]{1,16}"\]\n"""
+ r"""\}"""
)
# fmt: off
@sgl.function
def character_gen(s, name):
@@ -36,6 +46,38 @@ def character_gen(s, name):
s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
# fmt: on
# fmt: off
@sgl.function
def city_gen(s, document):
s += "Please extract the information of a city from the following wikipedia page.\n"
s += "Page begin.\n" + document + "Page end.\n"
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += sgl.gen("json_output",max_tokens=256, regex=city_regex)
# fmt: on
def bench_city_doc(args):
arguments = []
for line in read_jsonl(args.data_path):
arguments.append({"document": line["document"]})
arguments = arguments[: args.num_jsons]
# Select backend
backend = select_sglang_backend(args)
sgl.set_default_backend(backend)
# Run requests
tic = time.time()
states = city_gen.run_batch(
arguments,
temperature=0,
num_threads=args.parallel,
progress_bar=(args.parallel == 1),
)
latency = time.time() - tic
return states, latency
def bench_character(args):
arguments = []
@@ -62,14 +104,19 @@ def bench_character(args):
def main(args):
states, latency = bench_character(args)
if args.mode == "character":
args.data_path = "dataset.txt"
states, latency = bench_character(args)
elif args.mode == "city":
args.data_path = "questions.jsonl"
states, latency = bench_city_doc(args)
# Compute accuracy
print(f"Latency: {latency:.3f}")
# Write results
dump_state_text(f"tmp_output_{args.backend}.txt", states)
with open(f"{args.backend}.json", "w") as fout:
dump_state_text(f"tmp_output_{args.backend}_{args.mode}.txt", states)
with open(f"{args.backend}_{args.mode}.json", "w") as fout:
for state in states:
fout.write(state["json_output"] + "\n")
@@ -79,6 +126,7 @@ def main(args):
"backend": args.backend,
"latency": round(latency, 3),
"num_jsons": args.num_jsons,
"mode": args.mode,
"parallel": args.parallel,
}
fout.write(json.dumps(value) + "\n")
@@ -86,7 +134,10 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, default="dataset.txt")
parser.add_argument("--data-path", type=str)
parser.add_argument("--num-jsons", type=int, default=50)
parser.add_argument(
"--mode", type=str, default="character", choices=["character", "city"]
)
args = add_common_sglang_args_and_parse(parser)
main(args)