update bench_serving (#4958)
This commit is contained in:
@@ -44,6 +44,12 @@ ASSISTANT_SUFFIX = "Assistant:"
|
|||||||
global args
|
global args
|
||||||
|
|
||||||
|
|
||||||
|
# don't want to import sglang package here
|
||||||
|
def _get_bool_env_var(name: str, default: str = "false") -> bool:
|
||||||
|
value = os.getenv(name, default)
|
||||||
|
return value.lower() in ("true", "1")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RequestFuncInput:
|
class RequestFuncInput:
|
||||||
prompt: str
|
prompt: str
|
||||||
@@ -969,6 +975,7 @@ async def benchmark(
|
|||||||
extra_request_body: Dict[str, Any],
|
extra_request_body: Dict[str, Any],
|
||||||
profile: bool,
|
profile: bool,
|
||||||
pd_seperated: bool = False,
|
pd_seperated: bool = False,
|
||||||
|
flush_cache: bool = False,
|
||||||
):
|
):
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
@@ -1012,7 +1019,7 @@ async def benchmark(
|
|||||||
print("Initial test run completed. Starting main benchmark run...")
|
print("Initial test run completed. Starting main benchmark run...")
|
||||||
|
|
||||||
# Flush cache
|
# Flush cache
|
||||||
if "sglang" in backend:
|
if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
|
||||||
requests.post(base_url + "/flush_cache", headers=get_auth_headers())
|
requests.post(base_url + "/flush_cache", headers=get_auth_headers())
|
||||||
|
|
||||||
time.sleep(1.0)
|
time.sleep(1.0)
|
||||||
@@ -1347,6 +1354,10 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
tokenizer = get_tokenizer(tokenizer_id)
|
tokenizer = get_tokenizer(tokenizer_id)
|
||||||
input_requests = get_dataset(args, tokenizer)
|
input_requests = get_dataset(args, tokenizer)
|
||||||
|
|
||||||
|
# compatible with SimpleNamespace
|
||||||
|
if not hasattr(args, "flush_cache"):
|
||||||
|
args.flush_cache = False
|
||||||
|
|
||||||
return asyncio.run(
|
return asyncio.run(
|
||||||
benchmark(
|
benchmark(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
@@ -1362,6 +1373,7 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
profile=args.profile,
|
profile=args.profile,
|
||||||
pd_seperated=args.pd_seperated,
|
pd_seperated=args.pd_seperated,
|
||||||
|
flush_cache=args.flush_cache,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1543,6 +1555,11 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Benchmark PD disaggregation server",
|
help="Benchmark PD disaggregation server",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--flush-cache",
|
||||||
|
action="store_true",
|
||||||
|
help="Flush the cache before running the benchmark",
|
||||||
|
)
|
||||||
|
|
||||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user