diff --git a/benchmark/mmmu/README.md b/benchmark/mmmu/README.md
index 604f8f27e..e39bdd3c4 100644
--- a/benchmark/mmmu/README.md
+++ b/benchmark/mmmu/README.md
@@ -8,13 +8,15 @@ Host the VLM:
 python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
 ```
 
+It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
+
 Benchmark:
 
 ```
-python benchmark/mmmu/bench_sglang.py --port 30000
+python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
 ```
 
-It's recommended to reduce the memory usage by appending something ike `--mem-fraction-static 0.6` to the command above.
+You can adjust `--concurrency` to control the number of concurrent OpenAI API calls.
 
 ### Evaluate hf
 
diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py
index 55a7b1eaa..b2a2e2acd 100644
--- a/benchmark/mmmu/bench_sglang.py
+++ b/benchmark/mmmu/bench_sglang.py
@@ -4,7 +4,7 @@ Bench the sglang-hosted vLM with benchmark MMMU
 Usage:
     Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
 
-    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
+    Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
 
 The eval output will be logged
 """
@@ -15,7 +15,7 @@ import sys
 import time
 import traceback
 from dataclasses import dataclass, field
-from typing import List
+from typing import Any, List, Tuple
 
 import aiohttp
 import openai
@@ -65,22 +65,62 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
     return output
 
 
-async def eval_mmmu(args):
+def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
+    """Split the prompt into prefix and suffix."""
+    prefix = prompt.split("<")[0]
+    suffix = prompt.split(">", 1)[1]
+    return prefix, suffix
+
+
+async def process_sample(
+    client: Any, sample: dict, sampling_params: dict
+) -> Tuple[dict, str]:
+    """Send a single sample to the LLM and return (sample, response)."""
+    prompt = sample["final_input_prompt"]
+    prefix, suffix = _get_prefix_suffix(prompt)
+    image = sample["image"]
+    assert image is not None
+    image_path = sample["image_path"]
+    response = await client.chat.completions.create(
+        model="default",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": prefix},
+                    {"type": "image_url", "image_url": {"url": image_path}},
+                    {"type": "text", "text": suffix},
+                ],
+            }
+        ],
+        temperature=0,
+        max_completion_tokens=sampling_params["max_new_tokens"],
+        max_tokens=sampling_params["max_new_tokens"],
+    )
+    return sample, response.choices[0].message.content
+
+
+async def process_sample_with_semaphore(
+    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
+) -> Tuple[dict, str]:
+    """Wrap process_sample with a semaphore for concurrency control."""
+    async with semaphore:
+        return await process_sample(client, sample, sampling_params)
+
+
+async def eval_mmmu(args) -> None:
+    """Main evaluation loop with concurrency control."""
     eval_args = EvalArgs.from_cli_args(args)
-
-    out_samples = dict()
-
     sampling_params = get_sampling_params(eval_args)
-
     samples = prepare_samples(eval_args)
-
     answer_dict = {}
-
-    # had to use an openai server, since SglImage doesn't support image data
-    base_url = f"http://127.0.0.1:{args.port}"
-    client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
-
+    out_samples = {}
+    client = openai.AsyncOpenAI(
+        api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
+    )
+    semaphore = asyncio.Semaphore(args.concurrency)
     start = time.time()
+    base_url = f"http://127.0.0.1:{args.port}"
 
     if args.profile:
         print("Starting profiler...")
@@ -90,44 +130,15 @@ async def eval_mmmu(args):
         if profile_output.success:
             print("Profiler started")
 
-    if args.profile:
         samples = samples[: args.profile_number]
 
-    for i, sample in enumerate(tqdm(samples)):
-        prompt = sample["final_input_prompt"]
-        prefix = prompt.split("<")[0]
-        suffix = prompt.split(">")[1]
-        image = sample["image"]
-        assert image is not None
-        image_path = sample["image_path"]
-        # TODO: batch
+    tasks = [
+        process_sample_with_semaphore(semaphore, client, sample, sampling_params)
+        for sample in samples
+    ]
 
-        response = client.chat.completions.create(
-            model="default",
-            messages=[
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": prefix,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {"url": image_path},
-                        },
-                        {
-                            "type": "text",
-                            "text": suffix,
-                        },
-                    ],
-                }
-            ],
-            temperature=0,
-            max_completion_tokens=sampling_params["max_new_tokens"],
-            max_tokens=sampling_params["max_new_tokens"],
-        )
-        response = response.choices[0].message.content
+    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
+        sample, response = await coro
         process_result(response, sample, answer_dict, out_samples)
 
     if args.profile:
@@ -137,15 +148,22 @@
             print("Profiler stopped")
 
     print(f"Benchmark time: {time.time() - start}")
-
     args.output_path = f"./val_sglang.json"
     save_json(args.output_path, out_samples)
     eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
 
 
-if __name__ == "__main__":
+def parse_args():
     parser = argparse.ArgumentParser()
     EvalArgs.add_cli_args(parser)
     args = add_common_sglang_args_and_parse(parser)
-    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
     asyncio.run(eval_mmmu(args))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py
index 1a7db250e..a0960f9e0 100644
--- a/benchmark/mmmu/eval_utils.py
+++ b/benchmark/mmmu/eval_utils.py
@@ -35,6 +35,7 @@ class EvalArgs:
     extra_request_body: Optional[str] = None
     profile: bool = False
    profile_number: int = 5
+    concurrency: int = 1
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -73,6 +74,7 @@ class EvalArgs:
         parser.add_argument(
             "--profile-number", type=int, default=EvalArgs.profile_number
         )
+        parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
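The core of this patch is a bounded-concurrency pattern: a shared `asyncio.Semaphore` caps how many requests are in flight, and `asyncio.as_completed` consumes results as soon as each finishes. Below is a minimal, self-contained sketch of that pattern, not part of the patch itself; `fake_request` is a hypothetical stand-in for the real `client.chat.completions.create` call.

```python
import asyncio
import random


async def fake_request(sample: dict) -> tuple:
    # Hypothetical stand-in for the OpenAI call; the sleep simulates
    # network latency to the sglang server.
    await asyncio.sleep(random.uniform(0.05, 0.2))
    return sample, f"answer for sample {sample['id']}"


async def bounded_request(semaphore: asyncio.Semaphore, sample: dict) -> tuple:
    # Mirrors process_sample_with_semaphore: the semaphore caps how many
    # requests run concurrently; the rest wait here until a slot frees up.
    async with semaphore:
        return await fake_request(sample)


async def main() -> None:
    samples = [{"id": i} for i in range(32)]
    semaphore = asyncio.Semaphore(16)  # plays the role of --concurrency 16
    tasks = [bounded_request(semaphore, s) for s in samples]
    # as_completed yields each result as soon as it is ready, so fast
    # responses are processed without waiting for slow ones.
    for coro in asyncio.as_completed(tasks):
        sample, response = await coro
        print(sample["id"], response)


asyncio.run(main())
```

Capping concurrency this way keeps the server's batch queue full without flooding it with all requests at once, which is why the patch exposes the limit as a CLI flag rather than firing one task per sample unconditionally.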
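One subtle fix worth noting: `_get_prefix_suffix` uses `prompt.split(">", 1)[1]`, whereas the old inline code used `prompt.split(">")[1]`, which silently truncated the suffix at any later `>` character. A small sketch of the difference, assuming an MMMU-style prompt with a single `<image 1>` placeholder (the placeholder format and example prompt are illustrative assumptions):

```python
def _get_prefix_suffix(prompt: str) -> tuple:
    # Same logic as the patch: text before the first "<" is the prefix,
    # text after the first ">" is the suffix.
    prefix = prompt.split("<")[0]
    suffix = prompt.split(">", 1)[1]
    return prefix, suffix


# Hypothetical prompt whose suffix itself contains a ">".
prompt = "Given <image 1>, is x > y? Answer yes or no."
prefix, suffix = _get_prefix_suffix(prompt)
print(repr(prefix))  # 'Given '
print(repr(suffix))  # ', is x > y? Answer yes or no.'
```

With the old `prompt.split(">")[1]`, the suffix here would have been cut down to `', is x '`; `maxsplit=1` preserves everything after the image placeholder.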