Organize Benchmark (#381)
This commit is contained in:
@@ -2,61 +2,16 @@ import json
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
import requests
|
||||
from data_gen import gen_arguments
|
||||
from tqdm import tqdm
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
|
||||
def get_generate(args):
|
||||
# Select backend
|
||||
if args.backend == "vllm":
|
||||
url = f"{args.host}:{args.port}/generate"
|
||||
|
||||
def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"ignore_eos": True,
|
||||
"stop": stop,
|
||||
"stream": False,
|
||||
"n": n,
|
||||
}
|
||||
res = requests.post(url, json=data)
|
||||
assert res.status_code == 200
|
||||
return res.json()["text"][0][len(prompt) :]
|
||||
|
||||
elif args.backend == "guidance":
|
||||
from guidance import gen, models
|
||||
|
||||
model = models.LlamaCpp(
|
||||
"/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
|
||||
n_gpu_layers=-1,
|
||||
n_ctx=4096,
|
||||
)
|
||||
|
||||
def generate(prompt, max_tokens, stop=None):
|
||||
out = (
|
||||
model
|
||||
+ prompt
|
||||
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
|
||||
)
|
||||
return out["answer"]
|
||||
|
||||
# warmup
|
||||
for _ in range(3):
|
||||
generate("Hello!" * 10, max_tokens=64, stop=None)
|
||||
else:
|
||||
raise ValueError(f"Invalid backend: {args.backend}")
|
||||
|
||||
return generate
|
||||
|
||||
|
||||
def multi_turns(generate, qas):
|
||||
s = ""
|
||||
for qa in qas:
|
||||
@@ -75,10 +30,10 @@ def main(args):
|
||||
|
||||
states = [None] * args.num_qa
|
||||
|
||||
generate = get_generate(args)
|
||||
call_generate = partial(get_call_generate(args), temperature=0)
|
||||
|
||||
def get_one_answer(i):
|
||||
states[i] = multi_turns(generate=generate, **multi_qas[i])
|
||||
states[i] = multi_turns(generate=call_generate, **multi_qas[i])
|
||||
|
||||
tic = time.time()
|
||||
if args.parallel == 1:
|
||||
@@ -86,7 +41,12 @@ def main(args):
|
||||
get_one_answer(i)
|
||||
else:
|
||||
with ThreadPoolExecutor(args.parallel) as executor:
|
||||
rets = executor.map(get_one_answer, list(range(len(multi_qas))))
|
||||
rets = list(
|
||||
tqdm(
|
||||
executor.map(get_one_answer, list(range(len(multi_qas)))),
|
||||
total=len(multi_qas),
|
||||
)
|
||||
)
|
||||
for _ in rets:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user