Benchmark with Pytorch Profiler easily (#2110)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -388,6 +388,24 @@ async def async_request_gserver(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def async_request_profile(api_url: str) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
output = RequestFuncOutput()
|
||||
try:
|
||||
async with session.post(url=api_url) as response:
|
||||
if response.status == 200:
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
|
||||
import huggingface_hub.constants
|
||||
@@ -836,12 +854,14 @@ def calculate_metrics(
|
||||
async def benchmark(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
base_url: str,
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
extra_request_body: Dict[str, Any],
|
||||
profile: bool,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@@ -869,6 +889,14 @@ async def benchmark(
|
||||
|
||||
time.sleep(1.5)
|
||||
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
profile_output = await async_request_profile(
|
||||
api_url=base_url + "/start_profile"
|
||||
)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
@@ -890,6 +918,12 @@ async def benchmark(
|
||||
)
|
||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
print("Stopping profiler...")
|
||||
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
|
||||
if profile_output.success:
|
||||
print("Profiler stopped")
|
||||
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
|
||||
@@ -1114,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if args.base_url
|
||||
else f"http://{args.host}:{args.port}/v1/models/model:predict"
|
||||
)
|
||||
base_url = (
|
||||
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
|
||||
)
|
||||
|
||||
# Get model name
|
||||
if args.model is None:
|
||||
@@ -1159,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -1176,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1355,6 +1396,11 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
help="Path to load previously generated input data",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run_benchmark(args)
|
||||
|
||||
@@ -564,6 +564,7 @@ def run_bench_serving(
|
||||
disable_stream=disable_stream,
|
||||
disable_ignore_eos=False,
|
||||
extra_request_body=None,
|
||||
profile=None,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user