Benchmark with PyTorch Profiler easily (#2110)

Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
bjmsong
2024-11-22 15:29:50 +08:00
committed by GitHub
parent dfec7fca06
commit ad30d5cf9a
2 changed files with 48 additions and 1 deletions

View File

@@ -388,6 +388,24 @@ async def async_request_gserver(
raise NotImplementedError()
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST to *api_url* (a profiler start/stop endpoint) and report the outcome.

    Returns a RequestFuncOutput whose ``success`` flag is True only when the
    server answered HTTP 200; otherwise ``error`` holds the HTTP reason phrase
    or, if the request raised, the formatted traceback.
    """
    result = RequestFuncOutput()
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.post(url=api_url) as resp:
                result.success = resp.status == 200
                if not result.success:
                    result.error = resp.reason or ""
        except Exception:
            # Best-effort endpoint: record the full traceback instead of raising,
            # so the benchmark run can continue even if profiling is unavailable.
            result.success = False
            result.error = traceback.format_exc()
    return result
def get_model(pretrained_model_name_or_path: str) -> str:
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
import huggingface_hub.constants
@@ -836,12 +854,14 @@ def calculate_metrics(
async def benchmark(
backend: str,
api_url: str,
base_url: str,
model_id: str,
tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]],
request_rate: float,
disable_tqdm: bool,
extra_request_body: Dict[str, Any],
profile: bool,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -869,6 +889,14 @@ async def benchmark(
time.sleep(1.5)
if profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=base_url + "/start_profile"
)
if profile_output.success:
print("Profiler started")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter()
@@ -890,6 +918,12 @@ async def benchmark(
)
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
if profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
if profile_output.success:
print("Profiler stopped")
if pbar is not None:
pbar.close()
@@ -1114,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
if args.base_url
else f"http://{args.host}:{args.port}/v1/models/model:predict"
)
base_url = (
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
)
# Get model name
if args.model is None:
@@ -1159,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
else:
@@ -1176,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=rate,
disable_tqdm=args.disable_tqdm,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
@@ -1355,6 +1396,11 @@ if __name__ == "__main__":
type=str,
help="Path to load previously generated input data",
)
parser.add_argument(
"--profile",
action="store_true",
help="Use Torch Profiler. The endpoint must be launched with "
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
)
args = parser.parse_args()
run_benchmark(args)

View File

@@ -564,6 +564,7 @@ def run_bench_serving(
disable_stream=disable_stream,
disable_ignore_eos=False,
extra_request_body=None,
profile=None,
)
try: