From fda6bb78dad242c8d78e5af86eac9201cf2036ad Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 1 Apr 2025 15:10:56 -0700
Subject: [PATCH] update bench_serving (#4958)

---
 python/sglang/bench_serving.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 95cd6a392..6a8d4d00a 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -44,6 +44,12 @@ ASSISTANT_SUFFIX = "Assistant:"
 
 global args
 
+# don't want to import sglang package here
+def _get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    return value.lower() in ("true", "1")
+
+
 @dataclass
 class RequestFuncInput:
     prompt: str
@@ -969,6 +975,7 @@ async def benchmark(
     extra_request_body: Dict[str, Any],
     profile: bool,
     pd_seperated: bool = False,
+    flush_cache: bool = False,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1012,7 +1019,7 @@ async def benchmark(
         print("Initial test run completed. Starting main benchmark run...")
 
     # Flush cache
-    if "sglang" in backend:
+    if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
         requests.post(base_url + "/flush_cache", headers=get_auth_headers())
         time.sleep(1.0)
 
@@ -1347,6 +1354,10 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)
     input_requests = get_dataset(args, tokenizer)
 
+    # compatible with SimpleNamespace
+    if not hasattr(args, "flush_cache"):
+        args.flush_cache = False
+
     return asyncio.run(
         benchmark(
             backend=backend,
@@ -1362,6 +1373,7 @@ def run_benchmark(args_: argparse.Namespace):
             extra_request_body=extra_request_body,
             profile=args.profile,
             pd_seperated=args.pd_seperated,
+            flush_cache=args.flush_cache,
         )
     )
 
@@ -1543,6 +1555,11 @@ if __name__ == "__main__":
         action="store_true",
         help="Benchmark PD disaggregation server",
     )
+    parser.add_argument(
+        "--flush-cache",
+        action="store_true",
+        help="Flush the cache before running the benchmark",
+    )
 
     group = parser.add_argument_group("generated-shared-prefix dataset arguments")
     group.add_argument(