feat: add concurrency evaluation logic in mmmu benchmark (#5782)

This commit is contained in:
XinyuanTong
2025-05-01 18:20:08 -07:00
committed by GitHub
parent d33955d28a
commit c5645e928f
3 changed files with 75 additions and 53 deletions

View File

@@ -8,13 +8,15 @@ Host the VLM:
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000 python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
``` ```
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
Benchmark: Benchmark:
``` ```
python benchmark/mmmu/bench_sglang.py --port 30000 python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
``` ```
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above. You can adjust `--concurrency` to control the number of concurrent OpenAI calls.
### Evaluate hf ### Evaluate hf

View File

@@ -4,7 +4,7 @@ Bench the sglang-hosted vLM with benchmark MMMU
Usage: Usage:
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000 Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
The eval output will be logged The eval output will be logged
""" """
@@ -15,7 +15,7 @@ import sys
import time import time
import traceback import traceback
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import Any, List, Tuple
import aiohttp import aiohttp
import openai import openai
@@ -65,22 +65,62 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
return output return output
async def eval_mmmu(args): def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
"""Split the prompt into prefix and suffix."""
prefix = prompt.split("<")[0]
suffix = prompt.split(">", 1)[1]
return prefix, suffix
async def process_sample(
    client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Send a single sample to the LLM and return ``(sample, response)``.

    Args:
        client: Object exposing an awaitable ``chat.completions.create``
            (the caller passes an ``openai.AsyncOpenAI`` instance).
        sample: Sample dict. The keys ``"final_input_prompt"``, ``"image"``
            and ``"image_path"`` are read.
        sampling_params: Dict from which ``"max_new_tokens"`` is read.

    Returns:
        A tuple of the unchanged ``sample`` and the message content of the
        first choice in the completion response.

    Raises:
        ValueError: If ``sample["image"]`` is ``None``.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # when Python runs with ``-O``, which would silently skip this check.
    if sample["image"] is None:
        raise ValueError("sample['image'] must not be None")

    # The prompt text surrounding the image placeholder is sent as two text
    # parts with the image in between.
    prefix, suffix = _get_prefix_suffix(sample["final_input_prompt"])
    max_new_tokens = sampling_params["max_new_tokens"]

    response = await client.chat.completions.create(
        model="default",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prefix},
                    {"type": "image_url", "image_url": {"url": sample["image_path"]}},
                    {"type": "text", "text": suffix},
                ],
            }
        ],
        temperature=0,
        # Both limits carry the same value; presumably kept for compatibility
        # with servers that honor only one of the two names -- TODO confirm.
        max_completion_tokens=max_new_tokens,
        max_tokens=max_new_tokens,
    )
    return sample, response.choices[0].message.content
async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Run ``process_sample`` while holding ``semaphore``.

    The semaphore is acquired before the request is issued and released when
    it finishes (or raises), so the number of in-flight calls is bounded by
    the semaphore's initial value.

    Returns:
        Whatever ``process_sample`` returns for the given arguments.
    """
    async with semaphore:
        outcome = await process_sample(client, sample, sampling_params)
    return outcome
async def eval_mmmu(args) -> None:
"""Main evaluation loop with concurrency control."""
eval_args = EvalArgs.from_cli_args(args) eval_args = EvalArgs.from_cli_args(args)
out_samples = dict()
sampling_params = get_sampling_params(eval_args) sampling_params = get_sampling_params(eval_args)
samples = prepare_samples(eval_args) samples = prepare_samples(eval_args)
answer_dict = {} answer_dict = {}
out_samples = {}
# had to use an openai server, since SglImage doesn't support image data client = openai.AsyncOpenAI(
base_url = f"http://127.0.0.1:{args.port}" api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1") )
semaphore = asyncio.Semaphore(args.concurrency)
start = time.time() start = time.time()
base_url = f"http://127.0.0.1:{args.port}"
if args.profile: if args.profile:
print("Starting profiler...") print("Starting profiler...")
@@ -90,44 +130,15 @@ async def eval_mmmu(args):
if profile_output.success: if profile_output.success:
print("Profiler started") print("Profiler started")
if args.profile:
samples = samples[: args.profile_number] samples = samples[: args.profile_number]
for i, sample in enumerate(tqdm(samples)): tasks = [
prompt = sample["final_input_prompt"] process_sample_with_semaphore(semaphore, client, sample, sampling_params)
prefix = prompt.split("<")[0] for sample in samples
suffix = prompt.split(">")[1] ]
image = sample["image"]
assert image is not None
image_path = sample["image_path"]
# TODO: batch
response = client.chat.completions.create( for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
model="default", sample, response = await coro
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": prefix,
},
{
"type": "image_url",
"image_url": {"url": image_path},
},
{
"type": "text",
"text": suffix,
},
],
}
],
temperature=0,
max_completion_tokens=sampling_params["max_new_tokens"],
max_tokens=sampling_params["max_new_tokens"],
)
response = response.choices[0].message.content
process_result(response, sample, answer_dict, out_samples) process_result(response, sample, answer_dict, out_samples)
if args.profile: if args.profile:
@@ -137,15 +148,22 @@ async def eval_mmmu(args):
print("Profiler stopped") print("Profiler stopped")
print(f"Benchmark time: {time.time() - start}") print(f"Benchmark time: {time.time() - start}")
args.output_path = f"./val_sglang.json" args.output_path = f"./val_sglang.json"
save_json(args.output_path, out_samples) save_json(args.output_path, out_samples)
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
if __name__ == "__main__": def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
EvalArgs.add_cli_args(parser) EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser) args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args() return args
def main():
args = parse_args()
asyncio.run(eval_mmmu(args)) asyncio.run(eval_mmmu(args))
if __name__ == "__main__":
main()

View File

@@ -35,6 +35,7 @@ class EvalArgs:
extra_request_body: Optional[str] = None extra_request_body: Optional[str] = None
profile: bool = False profile: bool = False
profile_number: int = 5 profile_number: int = 5
concurrency: int = 1
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
@@ -73,6 +74,7 @@ class EvalArgs:
parser.add_argument( parser.add_argument(
"--profile-number", type=int, default=EvalArgs.profile_number "--profile-number", type=int, default=EvalArgs.profile_number
) )
parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):