Organize Benchmark (#381)

This commit is contained in:
Liangsheng Yin
2024-05-05 16:14:17 +08:00
committed by GitHub
parent 183df47282
commit 14522e6a26
36 changed files with 829 additions and 809 deletions

View File

@@ -2,61 +2,16 @@ import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import requests
from data_gen import gen_arguments
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer
from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text
def get_generate(args):
    """Build the text-generation callable for the backend named in ``args``.

    Args:
        args: Parsed command-line namespace. ``args.backend`` selects the
            backend. For ``"vllm"``, ``args.host`` and ``args.port`` are also
            read to build the ``/generate`` endpoint URL.

    Returns:
        A function ``generate(prompt, max_tokens, stop=None, temperature=0)``
        returning the generated text. The vllm variant additionally accepts
        ``url`` and ``n`` keyword arguments.

    Raises:
        ValueError: If ``args.backend`` is neither ``"vllm"`` nor
            ``"guidance"``.
    """
    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"

        def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
            """POST one generation request and return the first completion."""
            data = {
                "prompt": prompt,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "ignore_eos": True,
                "stop": stop,
                "stream": False,
                "n": n,
            }
            res = requests.post(url, json=data)
            # Explicit check instead of ``assert``: asserts are removed under
            # ``python -O`` and give no hint about what the server returned.
            if res.status_code != 200:
                raise RuntimeError(
                    f"Request to {url} failed with status "
                    f"{res.status_code}: {res.text}"
                )
            # Only the first returned text is used, with its first
            # len(prompt) characters dropped (presumably the echoed prompt --
            # verify against the server's response format).
            return res.json()["text"][0][len(prompt) :]

    elif args.backend == "guidance":
        from guidance import gen, models

        # NOTE(review): the model path is hard-coded to one machine's layout.
        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop=None, temperature=0):
            """Run one guidance generation and return the captured answer.

            ``temperature`` defaults to 0 (the value previously hard-coded)
            so this signature lines up with the vllm variant.
            """
            out = (
                model
                + prompt
                + gen(
                    name="answer",
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
            )
            return out["answer"]

        # warmup
        for _ in range(3):
            generate("Hello!" * 10, max_tokens=64, stop=None)

    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    return generate
def multi_turns(generate, qas):
s = ""
for qa in qas:
@@ -75,10 +30,10 @@ def main(args):
states = [None] * args.num_qa
generate = get_generate(args)
call_generate = partial(get_call_generate(args), temperature=0)
def get_one_answer(i):
states[i] = multi_turns(generate=generate, **multi_qas[i])
states[i] = multi_turns(generate=call_generate, **multi_qas[i])
tic = time.time()
if args.parallel == 1:
@@ -86,7 +41,12 @@ def main(args):
get_one_answer(i)
else:
with ThreadPoolExecutor(args.parallel) as executor:
rets = executor.map(get_one_answer, list(range(len(multi_qas))))
rets = list(
tqdm(
executor.map(get_one_answer, list(range(len(multi_qas)))),
total=len(multi_qas),
)
)
for _ in rets:
pass