From 6bebef60a7c385062ba858d8f28a97efbbb4ced5 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 26 May 2025 13:46:23 +0800 Subject: [PATCH] Support accurate length control for bench serving (#6594) --- python/sglang/bench_serving.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 343af051a..0737fcc8c 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -340,7 +340,7 @@ async def async_request_sglang_generate( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: payload = { - "text": prompt, + ("text" if isinstance(prompt, str) else "input_ids"): prompt, "sampling_params": { "temperature": 0.0, "max_new_tokens": request_func_input.output_len, @@ -494,6 +494,7 @@ def get_tokenizer( def get_dataset(args, tokenizer): if args.dataset_name == "sharegpt": + assert not args.tokenize_prompt input_requests = sample_sharegpt_requests( dataset_path=args.dataset_path, num_requests=args.num_prompts, @@ -512,8 +513,10 @@ def get_dataset(args, tokenizer): tokenizer=tokenizer, dataset_path=args.dataset_path, random_sample=args.dataset_name == "random", + return_text=not args.tokenize_prompt, ) elif args.dataset_name == "generated-shared-prefix": + assert not args.tokenize_prompt input_requests = sample_generated_shared_prefix_requests( num_groups=args.gsp_num_groups, prompts_per_group=args.gsp_prompts_per_group, @@ -524,6 +527,7 @@ def get_dataset(args, tokenizer): args=args, ) elif args.dataset_name == "mmmu": + assert not args.tokenize_prompt input_requests = sample_mmmu_requests( num_requests=args.num_prompts, tokenizer=tokenizer, @@ -1495,6 +1499,9 @@ def run_benchmark(args_: argparse.Namespace): if not hasattr(args, "output_details"): args.output_details = False + if not hasattr(args, "tokenize_prompt"): + args.tokenize_prompt = False + print(f"benchmark_args={args}") # Set global environments @@ -1506,6 +1513,11 @@ def run_benchmark(args_: argparse.Namespace): if args.extra_request_body: extra_request_body = json.loads(args.extra_request_body) + if args.tokenize_prompt: + assert ( + args.backend == "sglang" + ), "`--tokenize-prompt` only compatible with `--backend sglang` currently" + # Set url if args.port is None: args.port = { @@ -1812,6 +1824,11 @@ if __name__ == "__main__": default=1, help="Number of warmup requests to run before the benchmark", ) + parser.add_argument( + "--tokenize-prompt", + action="store_true", + help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately", + ) group = parser.add_argument_group("generated-shared-prefix dataset arguments") group.add_argument(