feat: add concurrency evaluation logic in mmmu benchmark (#5782)
This commit is contained in:
@@ -8,13 +8,15 @@ Host the VLM:
|
|||||||
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
|
||||||
|
|
||||||
Benchmark:
|
Benchmark:
|
||||||
|
|
||||||
```
|
```
|
||||||
python benchmark/mmmu/bench_sglang.py --port 30000
|
python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
|
||||||
```
|
```
|
||||||
|
|
||||||
It's recommended to reduce the memory usage by appending something like `--mem-fraction-static 0.6` to the command above.
|
You can adjust `--concurrency` to control the number of concurrent OpenAI API calls.
|
||||||
|
|
||||||
### Evaluate hf
|
### Evaluate hf
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Bench the sglang-hosted vLM with benchmark MMMU
|
|||||||
Usage:
|
Usage:
|
||||||
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
Host the VLM: python -m sglang.launch_server --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl --port 30000
|
||||||
|
|
||||||
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000
|
Benchmark: python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
|
||||||
|
|
||||||
The eval output will be logged
|
The eval output will be logged
|
||||||
"""
|
"""
|
||||||
@@ -15,7 +15,7 @@ import sys
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import openai
|
import openai
|
||||||
@@ -65,22 +65,62 @@ async def async_request_profile(api_url: str) -> RequestFuncOutput:
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
async def eval_mmmu(args):
|
def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
|
||||||
|
"""Split the prompt into prefix and suffix."""
|
||||||
|
prefix = prompt.split("<")[0]
|
||||||
|
suffix = prompt.split(">", 1)[1]
|
||||||
|
return prefix, suffix
|
||||||
|
|
||||||
|
|
||||||
|
async def process_sample(
    client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Issue one chat-completion request for ``sample``.

    Args:
        client: Object exposing an awaitable ``chat.completions.create``.
        sample: Mapping read for ``final_input_prompt``, ``image`` and
            ``image_path``.
        sampling_params: Mapping read for ``max_new_tokens``.

    Returns:
        The untouched ``sample`` paired with the message content of the
        first choice in the response.
    """
    text_before, text_after = _get_prefix_suffix(sample["final_input_prompt"])
    assert sample["image"] is not None

    # The image is sent between the two text halves of the prompt.
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": text_before},
            {"type": "image_url", "image_url": {"url": sample["image_path"]}},
            {"type": "text", "text": text_after},
        ],
    }
    token_budget = sampling_params["max_new_tokens"]
    completion = await client.chat.completions.create(
        model="default",
        messages=[user_message],
        temperature=0,
        # Both token-limit spellings are sent with the same value.
        max_completion_tokens=token_budget,
        max_tokens=token_budget,
    )
    return sample, completion.choices[0].message.content
|
||||||
|
|
||||||
|
|
||||||
|
async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
) -> Tuple[dict, str]:
    """Run ``process_sample`` while holding one slot of ``semaphore``.

    The slot is acquired before the request is made and released once
    ``process_sample`` returns or raises.

    Returns:
        Whatever ``process_sample`` returns for this sample.
    """
    await semaphore.acquire()
    try:
        outcome = await process_sample(client, sample, sampling_params)
    finally:
        semaphore.release()
    return outcome
|
||||||
|
|
||||||
|
|
||||||
|
async def eval_mmmu(args) -> None:
|
||||||
|
"""Main evaluation loop with concurrency control."""
|
||||||
eval_args = EvalArgs.from_cli_args(args)
|
eval_args = EvalArgs.from_cli_args(args)
|
||||||
|
|
||||||
out_samples = dict()
|
|
||||||
|
|
||||||
sampling_params = get_sampling_params(eval_args)
|
sampling_params = get_sampling_params(eval_args)
|
||||||
|
|
||||||
samples = prepare_samples(eval_args)
|
samples = prepare_samples(eval_args)
|
||||||
|
|
||||||
answer_dict = {}
|
answer_dict = {}
|
||||||
|
out_samples = {}
|
||||||
# had to use an openai server, since SglImage doesn't support image data
|
client = openai.AsyncOpenAI(
|
||||||
base_url = f"http://127.0.0.1:{args.port}"
|
api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1"
|
||||||
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
|
)
|
||||||
|
semaphore = asyncio.Semaphore(args.concurrency)
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
base_url = f"http://127.0.0.1:{args.port}"
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
print("Starting profiler...")
|
print("Starting profiler...")
|
||||||
@@ -90,44 +130,15 @@ async def eval_mmmu(args):
|
|||||||
if profile_output.success:
|
if profile_output.success:
|
||||||
print("Profiler started")
|
print("Profiler started")
|
||||||
|
|
||||||
if args.profile:
|
|
||||||
samples = samples[: args.profile_number]
|
samples = samples[: args.profile_number]
|
||||||
|
|
||||||
for i, sample in enumerate(tqdm(samples)):
|
tasks = [
|
||||||
prompt = sample["final_input_prompt"]
|
process_sample_with_semaphore(semaphore, client, sample, sampling_params)
|
||||||
prefix = prompt.split("<")[0]
|
for sample in samples
|
||||||
suffix = prompt.split(">")[1]
|
]
|
||||||
image = sample["image"]
|
|
||||||
assert image is not None
|
|
||||||
image_path = sample["image_path"]
|
|
||||||
# TODO: batch
|
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
|
||||||
model="default",
|
sample, response = await coro
|
||||||
messages=[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prefix,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": image_path},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": suffix,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
temperature=0,
|
|
||||||
max_completion_tokens=sampling_params["max_new_tokens"],
|
|
||||||
max_tokens=sampling_params["max_new_tokens"],
|
|
||||||
)
|
|
||||||
response = response.choices[0].message.content
|
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
process_result(response, sample, answer_dict, out_samples)
|
||||||
|
|
||||||
if args.profile:
|
if args.profile:
|
||||||
@@ -137,15 +148,22 @@ async def eval_mmmu(args):
|
|||||||
print("Profiler stopped")
|
print("Profiler stopped")
|
||||||
|
|
||||||
print(f"Benchmark time: {time.time() - start}")
|
print(f"Benchmark time: {time.time() - start}")
|
||||||
|
|
||||||
args.output_path = f"./val_sglang.json"
|
args.output_path = f"./val_sglang.json"
|
||||||
save_json(args.output_path, out_samples)
|
save_json(args.output_path, out_samples)
|
||||||
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
EvalArgs.add_cli_args(parser)
|
EvalArgs.add_cli_args(parser)
|
||||||
args = add_common_sglang_args_and_parse(parser)
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
args = parser.parse_args()
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
asyncio.run(eval_mmmu(args))
|
asyncio.run(eval_mmmu(args))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class EvalArgs:
|
|||||||
extra_request_body: Optional[str] = None
|
extra_request_body: Optional[str] = None
|
||||||
profile: bool = False
|
profile: bool = False
|
||||||
profile_number: int = 5
|
profile_number: int = 5
|
||||||
|
concurrency: int = 1
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -73,6 +74,7 @@ class EvalArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--profile-number", type=int, default=EvalArgs.profile_number
|
"--profile-number", type=int, default=EvalArgs.profile_number
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
Reference in New Issue
Block a user