diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index ee981982b..a02a3ec3a 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -39,7 +39,6 @@ from transformers import (
     PreTrainedTokenizerFast,
 )
 
-AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 ASSISTANT_SUFFIX = "Assistant:"
 
 global args
@@ -51,6 +50,19 @@ def _get_bool_env_var(name: str, default: str = "false") -> bool:
     return value.lower() in ("true", "1")
 
 
+def _create_bench_client_session():
+    # When the pressure is big, the read buffer could be full before aio thread read
+    # the content. We increase the read_bufsize from 64K to 10M.
+    # Define constants for timeout and buffer size for clarity and maintainability
+    BENCH_AIOHTTP_TIMEOUT_SECONDS = 6 * 60 * 60  # 6 hours
+    BENCH_AIOHTTP_READ_BUFSIZE_BYTES = 10 * 1024**2  # 10 MB
+
+    aiohttp_timeout = aiohttp.ClientTimeout(total=BENCH_AIOHTTP_TIMEOUT_SECONDS)
+    return aiohttp.ClientSession(
+        timeout=aiohttp_timeout, read_bufsize=BENCH_AIOHTTP_READ_BUFSIZE_BYTES
+    )
+
+
 @dataclass
 class RequestFuncInput:
     prompt: str
@@ -106,7 +118,7 @@ async def async_request_trt_llm(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with _create_bench_client_session() as session:
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -179,7 +191,7 @@ async def async_request_openai_completions(
 
     prompt = request_func_input.prompt
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with _create_bench_client_session() as session:
         payload = {
             "model": request_func_input.model,
             "prompt": prompt,
@@ -261,7 +273,7 @@ async def async_request_truss(
 
     prompt = request_func_input.prompt
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with _create_bench_client_session() as session:
         payload = {
             "model": request_func_input.model,
             "prompt": prompt,
@@ -338,7 +350,7 @@ async def async_request_sglang_generate(
     api_url = request_func_input.api_url
     prompt = request_func_input.prompt
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with _create_bench_client_session() as session:
         payload = {
             ("text" if isinstance(prompt, str) else "input_ids"): prompt,
             "sampling_params": {
@@ -437,7 +449,7 @@ async def async_request_gserver(
 
 
 async def async_request_profile(api_url: str) -> RequestFuncOutput:
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with _create_bench_client_session() as session:
         output = RequestFuncOutput()
         try:
             async with session.post(url=api_url) as response: