From 954f4e6bd607ae8ed08cc60dab7c8117e1ff1776 Mon Sep 17 00:00:00 2001 From: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:06:19 +0900 Subject: [PATCH] benchmark json schema (#2030) --- benchmark/json_schema/README.md | 15 +++ benchmark/json_schema/bench_sglang.py | 138 ++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 benchmark/json_schema/README.md create mode 100644 benchmark/json_schema/bench_sglang.py diff --git a/benchmark/json_schema/README.md b/benchmark/json_schema/README.md new file mode 100644 index 000000000..e6790e1fc --- /dev/null +++ b/benchmark/json_schema/README.md @@ -0,0 +1,15 @@ +## Run benchmark + +### Benchmark sglang + +Run Llama-8b + +```bash +python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000 +``` + +Benchmark + +```bash +python3 bench_sglang.py +``` diff --git a/benchmark/json_schema/bench_sglang.py b/benchmark/json_schema/bench_sglang.py new file mode 100644 index 000000000..05c9783eb --- /dev/null +++ b/benchmark/json_schema/bench_sglang.py @@ -0,0 +1,138 @@ +import argparse +import json +import time +from typing import List, Tuple + +import jsonschema +from datasets import load_dataset + +import sglang as sgl +from sglang.test.test_utils import ( + add_common_sglang_args_and_parse, + select_sglang_backend, +) +from sglang.utils import dump_state_text + + +@sgl.function +def schema_gen(s, message: Tuple[str, str], json_schema: str): + system, user = message + s += sgl.system(system) + s += sgl.user(user) + s += sgl.assistant( + sgl.gen("json_output", temperature=0, max_tokens=256, json_schema=json_schema) + ) + + +def contains_formats(schema, formats: List[str]): + if isinstance(schema, dict): + if schema.get("format", None) in formats: + return True + for value in schema.values(): + if contains_formats(value, formats): + return True + elif isinstance(schema, list): + for item in schema: + if contains_formats(item, formats): + return True + return False + + +def convert_dataset(path: str): + raw_dataset = load_dataset(path) + dataset = [] + for data in raw_dataset["train"]: + messages = data["prompt"] + schema = data["schema"] + obj = json.loads(schema) + + # skip some corrupted examples + if obj.get("type", None) is None: + continue + + # skip schema with format "email" + # which is not supported by outlines for now + if contains_formats(obj, ["email"]): + continue + + system = messages[0] + user = messages[1] + assert system["role"] == "system", "invalid role" + assert user["role"] == "user", "invalid role" + assert len(messages) == 2, "invalid message length" + message = json.dumps(system["content"]), json.dumps(user["content"]) + dataset.append( + { + "message": message, + "json_schema": schema, + } + ) + + return dataset + + +def bench_schema(args): + arguments = convert_dataset(args.data_path) + + if args.num_jsons < 0 or args.num_jsons > len(arguments): + args.num_jsons = len(arguments) + arguments = arguments[: args.num_jsons] + + # Select backend + backend = select_sglang_backend(args) + sgl.set_default_backend(backend) + + # Run requests + tic = time.time() + states = schema_gen.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.time() - tic + + # Check if the outputs are valid + indexs = [] + for i, state in enumerate(states): + try: + schema = json.loads(arguments[i]["json_schema"]) + obj = json.loads(state["json_output"]) + assert jsonschema.validate(obj, schema) is None + except Exception as e: + print(e) + indexs.append(i) + + assert len(indexs) == 0, f"Invalid json outputs: {indexs}" + return states, latency + + +def main(args): + states, latency = bench_schema(args) + + # Compute accuracy + print(f"Latency: {latency:.3f}") + + # Write results + dump_state_text(f"tmp_output_{args.backend}.txt", states) + with open(f"{args.backend}.json", "w") as fout: + for state in states: + fout.write(state["json_output"] + "\n") + + with open(args.result_file, "a") as fout: + value = { + "task": "json_schema", + "backend": args.backend, + "latency": round(latency, 3), + "num_jsons": args.num_jsons, + "parallel": args.parallel, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", type=str, default="NousResearch/json-mode-eval") + parser.add_argument("--num-jsons", type=int, default=-1) + args = add_common_sglang_args_and_parse(parser) + main(args)