Organize Benchmark (#381)

This commit is contained in:
Liangsheng Yin
2024-05-05 16:14:17 +08:00
committed by GitHub
parent 183df47282
commit 14522e6a26
36 changed files with 829 additions and 809 deletions

View File

@@ -7,10 +7,7 @@ from functools import partial
from tqdm import tqdm
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.test.test_utils import (
add_common_other_args_and_parse,
call_generate_outlines,
)
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
@@ -50,41 +47,11 @@ def main(args):
states = [None] * len(arguments)
# Select backend
if args.backend == "vllm":
url = f"{args.host}:{args.port}/generate"
generate = partial(call_generate_outlines, url=url, temperature=0)
elif args.backend == "guidance":
from guidance import gen, models
model = models.LlamaCpp(
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
n_gpu_layers=-1,
n_ctx=4096,
)
def generate(prompt, max_tokens, stop=None, regex=None):
out = (
model
+ prompt
+ gen(
name="answer",
max_tokens=max_tokens,
temperature=0,
stop=stop,
regex=regex,
)
)
return out["answer"]
# warmup
for _ in range(3):
generate("Hello!" * 10, max_tokens=64, stop=None)
else:
raise ValueError(f"Invalid backend: {args.backend}")
call_generate = partial(get_call_generate(args), temperature=0)
# Run requests
def get_one_answer(i):
states[i] = json_decode(generate=generate, **arguments[i])
states[i] = json_decode(generate=call_generate, **arguments[i])
tic = time.time()
if args.parallel == 1:
@@ -92,7 +59,12 @@ def main(args):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(arguments))))
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(arguments)))),
total=len(arguments),
)
)
for _ in rets:
pass