Support sgl-router parallel_batch in bench_one_batch_server (#10506)

This commit is contained in:
fzyzcjy
2025-09-16 17:52:57 +08:00
committed by GitHub
parent ae4be601c2
commit 8df7353af3

View File

@@ -48,6 +48,7 @@ class BenchArgs:
profile_steps: int = 3
profile_by_stage: bool = False
dataset_path: str = ""
parallel_batch: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -90,6 +91,7 @@ class BenchArgs:
default=BenchArgs.dataset_path,
help="Path to the dataset.",
)
parser.add_argument("--parallel-batch", action="store_true")
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -146,6 +148,7 @@ def run_one_case(
profile_steps: int = 3,
profile_by_stage: bool = False,
dataset_path: str = "",
parallel_batch: bool = False,
):
requests.post(url + "/flush_cache")
input_requests = sample_random_requests(
@@ -192,6 +195,7 @@ def run_one_case(
},
"return_logprob": return_logprob,
"stream": True,
**({"parallel_batch": parallel_batch} if parallel_batch else {}),
},
stream=True,
)
@@ -354,6 +358,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
result_filename="",
tokenizer=tokenizer,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
)
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
@@ -378,6 +383,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
result_filename=bench_args.result_filename,
tokenizer=tokenizer,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
)
)
@@ -404,6 +410,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
profile_steps=bench_args.profile_steps,
profile_by_stage=bench_args.profile_by_stage,
dataset_path=bench_args.dataset_path,
parallel_batch=bench_args.parallel_batch,
)[-1],
)
)