diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 2c33c6ac5..5353ec138 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -388,6 +388,30 @@ async def async_request_gserver(
     raise NotImplementedError()
 
 
+async def async_request_profile(api_url: str) -> RequestFuncOutput:
+    """POST to a profiler control endpoint and report whether it succeeded.
+
+    This function sets only the ``success`` and ``error`` fields of the result:
+    ``success`` is True for an HTTP 200 response; otherwise ``error`` holds the
+    response reason phrase or the formatted traceback of the raised exception.
+    """
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        output = RequestFuncOutput()
+        try:
+            async with session.post(url=api_url) as response:
+                if response.status == 200:
+                    output.success = True
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    return output
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
         import huggingface_hub.constants
@@ -836,12 +860,14 @@ def calculate_metrics(
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -869,6 +895,16 @@ async def benchmark(
 
     time.sleep(1.5)
 
+    if profile:
+        print("Starting profiler...")
+        profile_output = await async_request_profile(
+            api_url=base_url + "/start_profile"
+        )
+        if profile_output.success:
+            print("Profiler started")
+        else:
+            print(f"Failed to start profiler: {profile_output.error}")
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))
 
     benchmark_start_time = time.perf_counter()
@@ -890,6 +926,14 @@ async def benchmark(
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
 
+    if profile:
+        print("Stopping profiler...")
+        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
+        if profile_output.success:
+            print("Profiler stopped")
+        else:
+            print(f"Failed to stop profiler: {profile_output.error}")
+
     if pbar is not None:
         pbar.close()
 
@@ -1114,6 +1158,9 @@ def run_benchmark(args_: argparse.Namespace):
             if args.base_url
             else f"http://{args.host}:{args.port}/v1/models/model:predict"
         )
+    base_url = (
+        f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
+    )
 
     # Get model name
     if args.model is None:
@@ -1159,12 +1206,14 @@ def run_benchmark(args_: argparse.Namespace):
             benchmark(
                 backend=backend,
                 api_url=api_url,
+                base_url=base_url,
                 model_id=model_id,
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=args.request_rate,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
+                profile=args.profile,
             )
         )
     else:
@@ -1176,12 +1225,14 @@ def run_benchmark(args_: argparse.Namespace):
                 benchmark(
                     backend=backend,
                     api_url=api_url,
+                    base_url=base_url,
                     model_id=model_id,
                     tokenizer=tokenizer,
                     input_requests=input_requests,
                     request_rate=rate,
                     disable_tqdm=args.disable_tqdm,
                     extra_request_body=extra_request_body,
+                    profile=args.profile,
                 )
             )
 
@@ -1355,6 +1406,11 @@ if __name__ == "__main__":
         type=str,
         help="Path to load previously generated input data",
     )
-
+    parser.add_argument(
+        "--profile",
+        action="store_true",
+        help="Use Torch Profiler. The endpoint must be launched with "
+        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 3f991b39b..5aba58d7a 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -564,6 +564,7 @@ def run_bench_serving(
         disable_stream=disable_stream,
         disable_ignore_eos=False,
         extra_request_body=None,
+        profile=False,
     )
 
     try: