Tiny refactor bench_serving to extract RequestFuncOutput.init_new (#6108)

This commit is contained in:
fzyzcjy
2025-05-18 08:08:52 +08:00
committed by GitHub
parent 02973cd9a4
commit 26ebb849eb

View File

@@ -73,6 +73,12 @@ class RequestFuncOutput:
    error: str = ""
    output_len: int = 0
@staticmethod
def init_new(request_func_input: RequestFuncInput) -> "RequestFuncOutput":
    """Create a RequestFuncOutput pre-seeded from *request_func_input*.

    Copies ``prompt_len`` across; every other field keeps its dataclass
    default (e.g. ``error=""``, ``output_len=0``).  Used by the per-backend
    async request functions so they all initialize outputs the same way.
    """
    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len
    return output
def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text
@@ -114,8 +120,7 @@ async def async_request_trt_llm(
     if args.disable_ignore_eos:
         del payload["min_length"]
         del payload["end_id"]

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     ttft = 0.0
     st = time.perf_counter()
@@ -186,8 +191,7 @@ async def async_request_openai_completions(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -269,8 +273,7 @@ async def async_request_truss(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     ttft = 0.0
@@ -355,8 +358,7 @@ async def async_request_sglang_generate(
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len