feat: add concurrency evaluation logic in mmmu benchmark (#5782)
This commit is contained in:
@@ -8,13 +8,15 @@ Host the VLM:
|
||||
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
||||
```
|
||||
|
||||
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
|
||||
|
||||
Benchmark:
|
||||
|
||||
```
|
||||
python benchmark/mmmu/bench_sglang.py --port 30000
|
||||
python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
|
||||
```
|
||||
|
||||
It's recommended to reduce the memory usage by appending something ike `--mem-fraction-static 0.6` to the command above.
|
||||
You can adjust the `--concurrency` to control the number of concurrent OpenAI calls.
|
||||
|
||||
### Evaluate hf
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ Bench the sglang-hosted vLM with benchmark MMMU
|
||||
Usage:
|
||||
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
||||
|
||||
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
|
||||
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
|
||||
|
||||
The eval output will be logged
|
||||
"""
|
||||
@@ -15,7 +15,7 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
@@ -65,22 +65,62 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
|
||||
return output
|
||||
|
||||
|
||||
async def eval_mmmu(args):
|
||||
def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
|
||||
"""Split the prompt into prefix and suffix."""
|
||||
prefix = prompt.split("<")[0]
|
||||
suffix = prompt.split(">", 1)[1]
|
||||
return prefix, suffix
|
||||
|
||||
|
||||
async def process_sample(
|
||||
client: Any, sample: dict, sampling_params: dict
|
||||
) -> Tuple[dict, str]:
|
||||
"""Send a single sample to the LLM and return (sample, response)."""
|
||||
prompt = sample["final_input_prompt"]
|
||||
prefix, suffix = _get_prefix_suffix(prompt)
|
||||
image = sample["image"]
|
||||
assert image is not None
|
||||
image_path = sample["image_path"]
|
||||
response = await client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": prefix},
|
||||
{"type": "image_url", "image_url": {"url": image_path}},
|
||||
{"type": "text", "text": suffix},
|
||||
],
|
||||
}
|
||||
],
|
||||
temperature=0,
|
||||
max_completion_tokens=sampling_params["max_new_tokens"],
|
||||
max_tokens=sampling_params["max_new_tokens"],
|
||||
)
|
||||
return sample, response.choices[0].message.content
|
||||
|
||||
|
||||
async def process_sample_with_semaphore(
|
||||
semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
|
||||
) -> Tuple[dict, str]:
|
||||
"""Wrap process_sample with a semaphore for concurrency control."""
|
||||
async with semaphore:
|
||||
return await process_sample(client, sample, sampling_params)
|
||||
|
||||
|
||||
async def eval_mmmu(args) -> None:
|
||||
"""Main evaluation loop with concurrency control."""
|
||||
eval_args = EvalArgs.from_cli_args(args)
|
||||
|
||||
out_samples = dict()
|
||||
|
||||
sampling_params = get_sampling_params(eval_args)
|
||||
|
||||
samples = prepare_samples(eval_args)
|
||||
|
||||
answer_dict = {}
|
||||
|
||||
# had to use an openai server, since SglImage doesn't support image data
|
||||
base_url = f"http://127.0.0.1:{args.port}"
|
||||
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
|
||||
|
||||
out_samples = {}
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
|
||||
)
|
||||
semaphore = asyncio.Semaphore(args.concurrency)
|
||||
start = time.time()
|
||||
base_url = f"http://127.0.0.1:{args.port}"
|
||||
|
||||
if args.profile:
|
||||
print("Starting profiler...")
|
||||
@@ -90,44 +130,15 @@ async def eval_mmmu(args):
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
if args.profile:
|
||||
samples = samples[: args.profile_number]
|
||||
|
||||
for i, sample in enumerate(tqdm(samples)):
|
||||
prompt = sample["final_input_prompt"]
|
||||
prefix = prompt.split("<")[0]
|
||||
suffix = prompt.split(">")[1]
|
||||
image = sample["image"]
|
||||
assert image is not None
|
||||
image_path = sample["image_path"]
|
||||
# TODO: batch
|
||||
tasks = [
|
||||
process_sample_with_semaphore(semaphore, client, sample, sampling_params)
|
||||
for sample in samples
|
||||
]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prefix,
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_path},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": suffix,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
temperature=0,
|
||||
max_completion_tokens=sampling_params["max_new_tokens"],
|
||||
max_tokens=sampling_params["max_new_tokens"],
|
||||
)
|
||||
response = response.choices[0].message.content
|
||||
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
|
||||
sample, response = await coro
|
||||
process_result(response, sample, answer_dict, out_samples)
|
||||
|
||||
if args.profile:
|
||||
@@ -137,15 +148,22 @@ async def eval_mmmu(args):
|
||||
print("Profiler stopped")
|
||||
|
||||
print(f"Benchmark time: {time.time() - start}")
|
||||
|
||||
args.output_path = f"./val_sglang.json"
|
||||
save_json(args.output_path, out_samples)
|
||||
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
EvalArgs.add_cli_args(parser)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
asyncio.run(eval_mmmu(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -35,6 +35,7 @@ class EvalArgs:
|
||||
extra_request_body: Optional[str] = None
|
||||
profile: bool = False
|
||||
profile_number: int = 5
|
||||
concurrency: int = 1
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -73,6 +74,7 @@ class EvalArgs:
|
||||
parser.add_argument(
|
||||
"--profile-number", type=int, default=EvalArgs.profile_number
|
||||
)
|
||||
parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user