Files
sglang/benchmark/mmmu/bench_sglang.py
2025-04-29 23:48:27 -07:00

152 lines
4.5 KiB
Python

"""
Benchmark an sglang-hosted VLM on the MMMU benchmark.
Usage:
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
The eval output will be logged
"""
import argparse
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import List
import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
EvalArgs,
eval_result,
get_sampling_params,
prepare_samples,
process_result,
)
from tqdm import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse
# Overall timeout for the aiohttp session used by the profiler control
# requests: 20 * 60 * 60 seconds, i.e. 20 hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
@dataclass
class RequestFuncOutput:
    """Result record returned by the request helpers in this script.

    In the code shown here only ``success`` and ``error`` are written (by
    ``async_request_profile``); the remaining fields keep their empty
    defaults. Their names suggest per-request generation metrics --
    NOTE(review): confirm intended units/contents before relying on them.
    """

    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    # True only when the request completed with the expected status.
    success: bool = False
    # Failure description (HTTP reason phrase or formatted traceback); "" if none.
    error: str = ""
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """Send an empty POST to a profiler control endpoint and report the outcome.

    Args:
        api_url: Full URL to POST to.

    Returns:
        A ``RequestFuncOutput`` whose ``success`` is True only when the server
        answered with HTTP 200. For any other status ``error`` holds the
        response reason phrase (or ""); if the request raised, ``error`` holds
        the formatted traceback.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as http:
        result = RequestFuncOutput()
        try:
            async with http.post(url=api_url) as resp:
                result.success = resp.status == 200
                if not result.success:
                    result.error = resp.reason or ""
        except Exception:
            # Failures are reported through the result object, not raised.
            result.success = False
            result.error = traceback.format_exc()
        return result
def _split_prompt_around_image(prompt):
    """Return ``(prefix, suffix)``: the text before and after the first ``<...>`` tag.

    The closing ``>`` is searched for only after the first ``<``, and the
    suffix keeps everything that follows it, so a stray ``>`` elsewhere in the
    prompt neither truncates the suffix nor raises. When the prompt contains
    no complete ``<...>`` tag, the whole prompt is returned as the prefix with
    an empty suffix.
    """
    before, opener, remainder = prompt.partition("<")
    if opener:
        _tag, closer, after = remainder.partition(">")
        if closer:
            return before, after
    return prompt, ""


async def eval_mmmu(args):
    """Run the MMMU samples through a local OpenAI-compatible server and score them.

    Each sample's prompt is split around its image placeholder and sent as a
    text / image_url / text chat message to ``http://127.0.0.1:<args.port>/v1``.
    Responses are handed to ``process_result`` and the collected outputs are
    saved and evaluated at the end.

    Args:
        args: Parsed CLI namespace. Reads ``args.port``, ``args.profile`` and
            (when profiling) ``args.profile_number``; ``args.output_path`` is
            overwritten with ``"./val_sglang.json"``.

    Side effects:
        Prints progress and the elapsed benchmark time, writes the output JSON
        file, and, when ``args.profile`` is set, POSTs to the server's
        ``/start_profile`` and ``/stop_profile`` endpoints.
    """
    eval_args = EvalArgs.from_cli_args(args)
    out_samples = {}
    sampling_params = get_sampling_params(eval_args)
    samples = prepare_samples(eval_args)
    answer_dict = {}

    # had to use an openai server, since SglImage doesn't support image data
    base_url = f"http://127.0.0.1:{args.port}"
    client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
    max_new_tokens = sampling_params["max_new_tokens"]

    # Monotonic clock: the value is only used to report an elapsed duration.
    start = time.perf_counter()

    if args.profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=f"{base_url}/start_profile"
        )
        if profile_output.success:
            print("Profiler started")
        else:
            # Surface the failure instead of silently benchmarking unprofiled.
            print(f"Failed to start profiler: {profile_output.error}")
        # Profiling runs only look at the first `profile_number` samples.
        samples = samples[: args.profile_number]

    try:
        for sample in tqdm(samples):
            prefix, suffix = _split_prompt_around_image(sample["final_input_prompt"])
            assert sample["image"] is not None
            image_path = sample["image_path"]
            # TODO: batch
            response = client.chat.completions.create(
                model="default",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prefix},
                            {"type": "image_url", "image_url": {"url": image_path}},
                            {"type": "text", "text": suffix},
                        ],
                    }
                ],
                temperature=0,
                # Both spellings of the output-length limit are sent.
                max_completion_tokens=max_new_tokens,
                max_tokens=max_new_tokens,
            )
            answer = response.choices[0].message.content
            process_result(answer, sample, answer_dict, out_samples)
    finally:
        # Stop the profiler even when a request or result-processing step
        # fails, so the server is not left profiling indefinitely.
        if args.profile:
            print("Stopping profiler...")
            profile_output = await async_request_profile(
                api_url=f"{base_url}/stop_profile"
            )
            if profile_output.success:
                print("Profiler stopped")
            else:
                print(f"Failed to stop profiler: {profile_output.error}")

    print(f"Benchmark time: {time.perf_counter() - start}")

    args.output_path = "./val_sglang.json"
    save_json(args.output_path, out_samples)
    eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # Register the MMMU evaluation options, then let the shared helper run on
    # the same parser. The helper's return value is not used: the namespace
    # passed to the benchmark comes from the explicit parse_args() call below.
    EvalArgs.add_cli_args(cli)
    add_common_sglang_args_and_parse(cli)
    asyncio.run(eval_mmmu(cli.parse_args()))