Benchmark with PyTorch Profiler easily (#2110)
Co-authored-by: root <bjmsong@126.com>
This commit is contained in:
@@ -388,6 +388,24 @@ async def async_request_gserver(
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """Send a body-less POST to ``api_url`` and report whether it succeeded.

    Used by the benchmark to hit the server's ``/start_profile`` and
    ``/stop_profile`` endpoints.

    Args:
        api_url: Full URL of the endpoint to POST to.

    Returns:
        A ``RequestFuncOutput`` whose ``success`` is ``True`` only when the
        server answered with HTTP 200. On any other status, ``error`` holds
        the response reason (or an empty string when there is none). If the
        request raised, ``error`` holds the formatted traceback instead.
    """
    result = RequestFuncOutput()

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.post(url=api_url) as resp:
                ok = resp.status == 200
                if not ok:
                    # Keep the server-provided reason for the caller to inspect.
                    result.error = resp.reason or ""
                result.success = ok
        except Exception:
            # Best-effort call: never propagate, just record what went wrong.
            result.success = False
            result.error = traceback.format_exc()

    return result
|
||||||
|
|
||||||
|
|
||||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||||
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
|
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
|
||||||
import huggingface_hub.constants
|
import huggingface_hub.constants
|
||||||
@@ -836,12 +854,14 @@ def calculate_metrics(
|
|||||||
async def benchmark(
|
async def benchmark(
|
||||||
backend: str,
|
backend: str,
|
||||||
api_url: str,
|
api_url: str,
|
||||||
|
base_url: str,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
input_requests: List[Tuple[str, int, int]],
|
input_requests: List[Tuple[str, int, int]],
|
||||||
request_rate: float,
|
request_rate: float,
|
||||||
disable_tqdm: bool,
|
disable_tqdm: bool,
|
||||||
extra_request_body: Dict[str, Any],
|
extra_request_body: Dict[str, Any],
|
||||||
|
profile: bool,
|
||||||
):
|
):
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
if backend in ASYNC_REQUEST_FUNCS:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||||
@@ -869,6 +889,14 @@ async def benchmark(
|
|||||||
|
|
||||||
time.sleep(1.5)
|
time.sleep(1.5)
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
print("Starting profiler...")
|
||||||
|
profile_output = await async_request_profile(
|
||||||
|
api_url=base_url + "/start_profile"
|
||||||
|
)
|
||||||
|
if profile_output.success:
|
||||||
|
print("Profiler started")
|
||||||
|
|
||||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
@@ -890,6 +918,12 @@ async def benchmark(
|
|||||||
)
|
)
|
||||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
print("Stopping profiler...")
|
||||||
|
profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
|
||||||
|
if profile_output.success:
|
||||||
|
print("Profiler stopped")
|
||||||
|
|
||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.close()
|
pbar.close()
|
||||||
|
|
||||||
@@ -1114,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
if args.base_url
|
if args.base_url
|
||||||
else f"http://{args.host}:{args.port}/v1/models/model:predict"
|
else f"http://{args.host}:{args.port}/v1/models/model:predict"
|
||||||
)
|
)
|
||||||
|
base_url = (
|
||||||
|
f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
|
||||||
|
)
|
||||||
|
|
||||||
# Get model name
|
# Get model name
|
||||||
if args.model is None:
|
if args.model is None:
|
||||||
@@ -1159,12 +1196,14 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
benchmark(
|
benchmark(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
|
base_url=base_url,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
input_requests=input_requests,
|
input_requests=input_requests,
|
||||||
request_rate=args.request_rate,
|
request_rate=args.request_rate,
|
||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
|
profile=args.profile,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -1176,12 +1215,14 @@ def run_benchmark(args_: argparse.Namespace):
|
|||||||
benchmark(
|
benchmark(
|
||||||
backend=backend,
|
backend=backend,
|
||||||
api_url=api_url,
|
api_url=api_url,
|
||||||
|
base_url=base_url,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
input_requests=input_requests,
|
input_requests=input_requests,
|
||||||
request_rate=rate,
|
request_rate=rate,
|
||||||
disable_tqdm=args.disable_tqdm,
|
disable_tqdm=args.disable_tqdm,
|
||||||
extra_request_body=extra_request_body,
|
extra_request_body=extra_request_body,
|
||||||
|
profile=args.profile,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1355,6 +1396,11 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
help="Path to load previously generated input data",
|
help="Path to load previously generated input data",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile",
|
||||||
|
action="store_true",
|
||||||
|
help="Use Torch Profiler. The endpoint must be launched with "
|
||||||
|
"SGLANG_TORCH_PROFILER_DIR to enable profiler.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
run_benchmark(args)
|
run_benchmark(args)
|
||||||
|
|||||||
@@ -564,6 +564,7 @@ def run_bench_serving(
|
|||||||
disable_stream=disable_stream,
|
disable_stream=disable_stream,
|
||||||
disable_ignore_eos=False,
|
disable_ignore_eos=False,
|
||||||
extra_request_body=None,
|
extra_request_body=None,
|
||||||
|
profile=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user