Format Benchmark Code (#399)
This commit is contained in:
@@ -4,12 +4,12 @@ from argparse import ArgumentParser
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse
|
||||
from sglang.utils import dump_state_text
|
||||
from data_gen import gen_arguments
|
||||
from tqdm import tqdm
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from data_gen import gen_arguments
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
|
||||
def get_generate(args):
|
||||
@@ -61,7 +61,7 @@ def multi_turns(generate, qas):
|
||||
s = ""
|
||||
for qa in qas:
|
||||
s += qa["prompt"]
|
||||
s += generate(s, max_tokens=qa["new_tokens"])
|
||||
s += generate(s, max_tokens=qa["new_tokens"])
|
||||
|
||||
return s
|
||||
|
||||
|
||||
@@ -2,22 +2,22 @@ import json
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from data_gen import gen_arguments
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from data_gen import gen_arguments
|
||||
|
||||
|
||||
@sgl.function
|
||||
def multi_turns(s, qas):
|
||||
for qa in qas:
|
||||
s += qa["prompt"]
|
||||
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
|
||||
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -29,7 +29,11 @@ def main(args):
|
||||
|
||||
tic = time.time()
|
||||
states = multi_turns.run_batch(
|
||||
multi_qas, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True
|
||||
multi_qas,
|
||||
temperature=0,
|
||||
backend=backend,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.time() - tic
|
||||
|
||||
|
||||
Reference in New Issue
Block a user