Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -56,6 +56,7 @@ class BenchArgs:
profile: bool = False
skip_warmup: bool = False
do_not_exit: bool = False
prompt_suffix: str = ""
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -177,6 +178,12 @@ class BenchArgs:
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -216,6 +223,10 @@ def throughput_test_once(
]
if profile:
assert (
"SGLANG_TORCH_PROFILER_DIR" in os.environ
), "Please set SGLANG_TORCH_PROFILER_DIR."
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
backend.start_profile()
st = time.perf_counter()
@@ -229,6 +240,8 @@ def throughput_test_once(
if backend_name == "runtime":
gen_out = json.loads(gen_out)
server_info = backend.get_server_info()
measurement_results["total_latency"] = latency
measurement_results["total_output_tokens"] = sum(
o["meta_info"]["completion_tokens"] for o in gen_out
@@ -246,6 +259,7 @@ def throughput_test_once(
measurement_results["total_input_tokens"]
+ measurement_results["total_output_tokens"]
) / latency
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
return measurement_results
@@ -361,6 +375,11 @@ def throughput_test(
print(
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
)
print(
"{:<40} {:<10.2f}".format(
"Last generation throughput (tok/s):", result["last_gen_throughput"]
)
)
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", result["request_throughput"]