Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -56,6 +56,7 @@ class BenchArgs:
|
||||
profile: bool = False
|
||||
skip_warmup: bool = False
|
||||
do_not_exit: bool = False
|
||||
prompt_suffix: str = ""
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -177,6 +178,12 @@ class BenchArgs:
|
||||
action="store_true",
|
||||
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-suffix",
|
||||
type=str,
|
||||
default="",
|
||||
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -216,6 +223,10 @@ def throughput_test_once(
|
||||
]
|
||||
|
||||
if profile:
|
||||
assert (
|
||||
"SGLANG_TORCH_PROFILER_DIR" in os.environ
|
||||
), "Please set SGLANG_TORCH_PROFILER_DIR."
|
||||
os.makedirs(os.environ["SGLANG_TORCH_PROFILER_DIR"], exist_ok=True)
|
||||
backend.start_profile()
|
||||
|
||||
st = time.perf_counter()
|
||||
@@ -229,6 +240,8 @@ def throughput_test_once(
|
||||
if backend_name == "runtime":
|
||||
gen_out = json.loads(gen_out)
|
||||
|
||||
server_info = backend.get_server_info()
|
||||
|
||||
measurement_results["total_latency"] = latency
|
||||
measurement_results["total_output_tokens"] = sum(
|
||||
o["meta_info"]["completion_tokens"] for o in gen_out
|
||||
@@ -246,6 +259,7 @@ def throughput_test_once(
|
||||
measurement_results["total_input_tokens"]
|
||||
+ measurement_results["total_output_tokens"]
|
||||
) / latency
|
||||
measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
|
||||
|
||||
return measurement_results
|
||||
|
||||
@@ -361,6 +375,11 @@ def throughput_test(
|
||||
print(
|
||||
"{:<40} {:<10}".format("Total generated tokens:", result["total_output_tokens"])
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Last generation throughput (tok/s):", result["last_gen_throughput"]
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Request throughput (req/s):", result["request_throughput"]
|
||||
|
||||
Reference in New Issue
Block a user