diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 348307524..a2706015d 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -54,6 +54,7 @@ class RequestFuncOutput: itl: List[float] = field(default_factory=list) # List of inter-token latencies prompt_len: int = 0 error: str = "" + output_len: int = 0 def remove_prefix(text: str, prefix: str) -> str: @@ -189,6 +190,7 @@ async def async_request_openai_completions( output.generated_text = generated_text output.success = True output.latency = latency + output.output_len = request_func_input.output_len else: output.error = response.reason or "" output.success = False @@ -451,6 +453,7 @@ def calculate_metrics( outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, + backend: str, ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 @@ -460,13 +463,16 @@ def calculate_metrics( ttfts: List[float] = [] for i in range(len(outputs)): if outputs[i].success: - # We use the tokenizer to count the number of output tokens for all - # serving backends instead of looking at len(outputs[i].itl) since - # multiple output tokens may be bundled together - # Note : this may inflate the output token count slightly - output_len = len( - tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids - ) + # We use the tokenizer solely to count output tokens for the TensorRT LLM backend, + # as it lacks `ignore_eos` support. + if backend == "trt": + output_len = len( + tokenizer( + outputs[i].generated_text, add_special_tokens=False + ).input_ids + ) + else: + output_len = outputs[i].output_len actual_output_lens.append(output_len) total_input += input_requests[i][1] if output_len > 1: @@ -571,9 +577,11 @@ async def benchmark( outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, + backend=backend, ) print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) + print("{:<40} {:<10}".format("Backend:", backend)) print("{:<40} {:<10}".format("Traffic request rate:", request_rate)) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))