Support accurate length control for bench serving (#6594)
This commit is contained in:
@@ -340,7 +340,7 @@ async def async_request_sglang_generate(
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"text": prompt,
|
||||
("text" if isinstance(prompt, str) else "input_ids"): prompt,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
@@ -494,6 +494,7 @@ def get_tokenizer(
|
||||
|
||||
def get_dataset(args, tokenizer):
|
||||
if args.dataset_name == "sharegpt":
|
||||
assert not args.tokenize_prompt
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
@@ -512,8 +513,10 @@ def get_dataset(args, tokenizer):
|
||||
tokenizer=tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=args.dataset_name == "random",
|
||||
return_text=not args.tokenize_prompt,
|
||||
)
|
||||
elif args.dataset_name == "generated-shared-prefix":
|
||||
assert not args.tokenize_prompt
|
||||
input_requests = sample_generated_shared_prefix_requests(
|
||||
num_groups=args.gsp_num_groups,
|
||||
prompts_per_group=args.gsp_prompts_per_group,
|
||||
@@ -524,6 +527,7 @@ def get_dataset(args, tokenizer):
|
||||
args=args,
|
||||
)
|
||||
elif args.dataset_name == "mmmu":
|
||||
assert not args.tokenize_prompt
|
||||
input_requests = sample_mmmu_requests(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@@ -1495,6 +1499,9 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if not hasattr(args, "output_details"):
|
||||
args.output_details = False
|
||||
|
||||
if not hasattr(args, "tokenize_prompt"):
|
||||
args.tokenize_prompt = False
|
||||
|
||||
print(f"benchmark_args={args}")
|
||||
|
||||
# Set global environments
|
||||
@@ -1506,6 +1513,11 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if args.extra_request_body:
|
||||
extra_request_body = json.loads(args.extra_request_body)
|
||||
|
||||
if args.tokenize_prompt:
|
||||
assert (
|
||||
args.backend == "sglang"
|
||||
), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
|
||||
|
||||
# Set url
|
||||
if args.port is None:
|
||||
args.port = {
|
||||
@@ -1812,6 +1824,11 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
help="Number of warmup requests to run before the benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenize-prompt",
|
||||
action="store_true",
|
||||
help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||
group.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user