Support accurate length control for bench serving (#6594)

This commit is contained in:
fzyzcjy
2025-05-26 13:46:23 +08:00
committed by GitHub
parent 25be63d0b2
commit 6bebef60a7

View File

@@ -340,7 +340,7 @@ async def async_request_sglang_generate(
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"text": prompt,
("text" if isinstance(prompt, str) else "input_ids"): prompt,
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": request_func_input.output_len,
@@ -494,6 +494,7 @@ def get_tokenizer(
def get_dataset(args, tokenizer):
if args.dataset_name == "sharegpt":
assert not args.tokenize_prompt
input_requests = sample_sharegpt_requests(
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
@@ -512,8 +513,10 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer,
dataset_path=args.dataset_path,
random_sample=args.dataset_name == "random",
return_text=not args.tokenize_prompt,
)
elif args.dataset_name == "generated-shared-prefix":
assert not args.tokenize_prompt
input_requests = sample_generated_shared_prefix_requests(
num_groups=args.gsp_num_groups,
prompts_per_group=args.gsp_prompts_per_group,
@@ -524,6 +527,7 @@ def get_dataset(args, tokenizer):
args=args,
)
elif args.dataset_name == "mmmu":
assert not args.tokenize_prompt
input_requests = sample_mmmu_requests(
num_requests=args.num_prompts,
tokenizer=tokenizer,
@@ -1495,6 +1499,9 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "output_details"):
args.output_details = False
if not hasattr(args, "tokenize_prompt"):
args.tokenize_prompt = False
print(f"benchmark_args={args}")
# Set global environments
@@ -1506,6 +1513,11 @@ def run_benchmark(args_: argparse.Namespace):
if args.extra_request_body:
extra_request_body = json.loads(args.extra_request_body)
if args.tokenize_prompt:
assert (
args.backend == "sglang"
), "`--tokenize-prompt` only compatible with `--backend sglang` currently"
# Set url
if args.port is None:
args.port = {
@@ -1812,6 +1824,11 @@ if __name__ == "__main__":
default=1,
help="Number of warmup requests to run before the benchmark",
)
parser.add_argument(
"--tokenize-prompt",
action="store_true",
help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
)
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(