Tiny refactor bench_serving to extract RequestFuncOutput.init_new (#6108)
This commit is contained in:
@@ -73,6 +73,12 @@ class RequestFuncOutput:
|
||||
error: str = ""
|
||||
output_len: int = 0
|
||||
|
||||
@staticmethod
def init_new(request_func_input: RequestFuncInput) -> "RequestFuncOutput":
    """Create a fresh output record seeded from a request input.

    A ``RequestFuncOutput`` is constructed with no arguments, so every
    field starts at its class default; only ``prompt_len`` is copied
    over from ``request_func_input``.

    Args:
        request_func_input: The request whose ``prompt_len`` should be
            carried into the new output.

    Returns:
        The newly created ``RequestFuncOutput``.
    """
    output = RequestFuncOutput()
    # Carry the prompt length so the output can be reported without
    # needing the originating request again.
    output.prompt_len = request_func_input.prompt_len
    return output
|
||||
|
||||
|
||||
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` stripped.

    Args:
        text: The string to inspect.
        prefix: The leading substring to drop when present.

    Returns:
        ``text`` without ``prefix`` when ``text`` starts with it;
        otherwise ``text`` unchanged. Only a single leading occurrence
        is removed.
    """
    if not text.startswith(prefix):
        return text
    return text[len(prefix):]
|
||||
@@ -114,8 +120,7 @@ async def async_request_trt_llm(
|
||||
if args.disable_ignore_eos:
|
||||
del payload["min_length"]
|
||||
del payload["end_id"]
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output = RequestFuncOutput.init_new(request_func_input)
|
||||
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
@@ -186,8 +191,7 @@ async def async_request_openai_completions(
|
||||
}
|
||||
headers = get_auth_headers()
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output = RequestFuncOutput.init_new(request_func_input)
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
@@ -269,8 +273,7 @@ async def async_request_truss(
|
||||
}
|
||||
headers = get_auth_headers()
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output = RequestFuncOutput.init_new(request_func_input)
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
@@ -355,8 +358,7 @@ async def async_request_sglang_generate(
|
||||
|
||||
headers = get_auth_headers()
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
output = RequestFuncOutput.init_new(request_func_input)
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
|
||||
Reference in New Issue
Block a user