Files
sglang/benchmark/mmmu/bench_sglang.py

152 lines
4.5 KiB
Python
Raw Normal View History

"""
Bench the sglang-hosted vLM with benchmark MMMU
Usage:
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
The eval output will be logged
"""
import argparse
2025-04-30 14:48:27 +08:00
import asyncio
import sys
import time
2025-04-30 14:48:27 +08:00
import traceback
from dataclasses import dataclass, field
from typing import List
2025-04-30 14:48:27 +08:00
import aiohttp
import openai
from data_utils import save_json
from eval_utils import (
EvalArgs,
eval_result,
get_sampling_params,
prepare_samples,
process_result,
)
from tqdm import tqdm
from sglang.test.test_utils import add_common_sglang_args_and_parse
2025-04-30 14:48:27 +08:00
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
2025-04-30 14:48:27 +08:00
@dataclass
class RequestFuncOutput:
generated_text: List[str] = field(default_factory=list)
prompt_len: List[int] = field(default_factory=list)
output_len: List[int] = field(default_factory=list)
latency: List[float] = field(default_factory=list)
ttft: List[float] = field(default_factory=list)
itl: List[float] = field(default_factory=list) # List of inter-token latencies
success: bool = False
error: str = ""
async def async_request_profile(api_url: str) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
output = RequestFuncOutput()
try:
async with session.post(url=api_url) as response:
if response.status == 200:
output.success = True
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
return output
async def eval_mmmu(args):
eval_args = EvalArgs.from_cli_args(args)
out_samples = dict()
sampling_params = get_sampling_params(eval_args)
samples = prepare_samples(eval_args)
answer_dict = {}
# had to use an openai server, since SglImage doesn't support image data
2025-04-30 14:48:27 +08:00
base_url = f"http://127.0.0.1:{args.port}"
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
start = time.time()
2025-04-30 14:48:27 +08:00
if args.profile:
print("Starting profiler...")
profile_output = await async_request_profile(
api_url=f"{base_url}/start_profile"
)
if profile_output.success:
print("Profiler started")
if args.profile:
samples = samples[: args.profile_number]
for i, sample in enumerate(tqdm(samples)):
prompt = sample["final_input_prompt"]
prefix = prompt.split("<")[0]
suffix = prompt.split(">")[1]
image = sample["image"]
assert image is not None
image_path = sample["image_path"]
# TODO: batch
2025-04-30 14:48:27 +08:00
response = client.chat.completions.create(
model="default",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prefix,
},
{
"type": "image_url",
"image_url": {"url": image_path},
},
{
"type": "text",
"text": suffix,
},
],
}
],
temperature=0,
max_completion_tokens=sampling_params["max_new_tokens"],
max_tokens=sampling_params["max_new_tokens"],
)
response = response.choices[0].message.content
process_result(response, sample, answer_dict, out_samples)
2025-04-30 14:48:27 +08:00
if args.profile:
print("Stopping profiler...")
profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
if profile_output.success:
print("Profiler stopped")
print(f"Benchmark time: {time.time() - start}")
args.output_path = f"./val_sglang.json"
save_json(args.output_path, out_samples)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args()
2025-04-30 14:48:27 +08:00
asyncio.run(eval_mmmu(args))