Support LoRA in MMMU benchmark script. (#7218)

This commit is contained in:
Lifu Huang
2025-06-15 21:17:57 -07:00
committed by GitHub
parent 3c2274fbee
commit e07d064729
3 changed files with 54 additions and 12 deletions

View File

@@ -18,6 +18,15 @@ python benchmark/mmmu/bench_sglang.py --port 30000 --concurrency 16
You can adjust `--concurrency` to control the number of concurrent OpenAI calls.
You can use `--lora-path` to specify the LoRA adapter to apply during benchmarking. E.g.,
```
# Launch server with LoRA enabled
python -m sglang.launch_server --model-path microsoft/Phi-4-multimodal-instruct --port 30000 --trust-remote-code --disable-radix-cache --lora-paths vision=<LoRA path>
# Apply LoRA adapter during inference
python benchmark/mmmu/bench_sglang.py --concurrency 8 --lora-path vision
```
### Evaluate hf
``` ```

View File

@@ -15,7 +15,7 @@ import sys
import time import time
import traceback import traceback
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, List, Tuple from typing import Any, List, Optional, Tuple
import aiohttp import aiohttp
import openai import openai
@@ -73,7 +73,7 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
async def process_sample( async def process_sample(
client: Any, sample: dict, sampling_params: dict client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]: ) -> Tuple[dict, str]:
"""Send a single sample to the LLM and return (sample, response).""" """Send a single sample to the LLM and return (sample, response)."""
prompt = sample["final_input_prompt"] prompt = sample["final_input_prompt"]
@@ -81,6 +81,7 @@ async def process_sample(
image = sample["image"] image = sample["image"]
assert image is not None assert image is not None
image_path = sample["image_path"] image_path = sample["image_path"]
extra_body = None if lora_path is None else {"lora_path": lora_path}
response = await client.chat.completions.create( response = await client.chat.completions.create(
model="default", model="default",
messages=[ messages=[
@@ -96,16 +97,21 @@ async def process_sample(
temperature=0, temperature=0,
max_completion_tokens=sampling_params["max_new_tokens"], max_completion_tokens=sampling_params["max_new_tokens"],
max_tokens=sampling_params["max_new_tokens"], max_tokens=sampling_params["max_new_tokens"],
extra_body=extra_body,
) )
return sample, response.choices[0].message.content return sample, response.choices[0].message.content
async def process_sample_with_semaphore( async def process_sample_with_semaphore(
semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict semaphore: asyncio.Semaphore,
client: Any,
sample: dict,
sampling_params: dict,
lora_path: Optional[str] = None,
) -> Tuple[dict, str]: ) -> Tuple[dict, str]:
"""Wrap process_sample with a semaphore for concurrency control.""" """Wrap process_sample with a semaphore for concurrency control."""
async with semaphore: async with semaphore:
return await process_sample(client, sample, sampling_params) return await process_sample(client, sample, sampling_params, lora_path)
async def eval_mmmu(args) -> None: async def eval_mmmu(args) -> None:
@@ -113,6 +119,7 @@ async def eval_mmmu(args) -> None:
eval_args = EvalArgs.from_cli_args(args) eval_args = EvalArgs.from_cli_args(args)
sampling_params = get_sampling_params(eval_args) sampling_params = get_sampling_params(eval_args)
samples = prepare_samples(eval_args) samples = prepare_samples(eval_args)
lora_path = eval_args.lora_path
answer_dict = {} answer_dict = {}
out_samples = {} out_samples = {}
client = openai.AsyncOpenAI( client = openai.AsyncOpenAI(
@@ -133,7 +140,9 @@ async def eval_mmmu(args) -> None:
samples = samples[: args.profile_number] samples = samples[: args.profile_number]
tasks = [ tasks = [
process_sample_with_semaphore(semaphore, client, sample, sampling_params) process_sample_with_semaphore(
semaphore, client, sample, sampling_params, lora_path
)
for sample in samples for sample in samples
] ]

View File

@@ -36,17 +36,22 @@ class EvalArgs:
profile: bool = False profile: bool = False
profile_number: int = 5 profile_number: int = 5
concurrency: int = 1 concurrency: int = 1
lora_path: Optional[str] = None
@staticmethod @staticmethod
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--result-filename", type=str, default=EvalArgs.result_filename "--result-filename",
type=str,
default=EvalArgs.result_filename,
help="The filename to save the evaluation results.",
) )
parser.add_argument( parser.add_argument(
"--image-pixels-limit", type=int, default=EvalArgs.image_pixels_limit "--image-pixels-limit",
type=int,
default=EvalArgs.image_pixels_limit,
help="The maximum number of pixels allowed for an image. If an image exceeds this limit, it will be skipped during evaluation.",
) )
parser.add_argument( parser.add_argument(
"--dataset-path", "--dataset-path",
type=str, type=str,
@@ -59,7 +64,12 @@ class EvalArgs:
type=str, type=str,
help="The path to the prompt format of mmmu. If not, a default format llava_config.yaml will be used", help="The path to the prompt format of mmmu. If not, a default format llava_config.yaml will be used",
) )
parser.add_argument("--split", type=str, default=EvalArgs.split) parser.add_argument(
"--split",
type=str,
default=EvalArgs.split,
help='Split of the dataset to use for evaluation. Default is "validation".',
)
parser.add_argument( parser.add_argument(
"--extra-request-body", "--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}', metavar='{"key1": "value1", "key2": "value2"}',
@@ -72,9 +82,23 @@ class EvalArgs:
"--profile", action="store_true", help="enable mmmu profile" "--profile", action="store_true", help="enable mmmu profile"
) )
parser.add_argument( parser.add_argument(
"--profile-number", type=int, default=EvalArgs.profile_number "--profile-number",
type=int,
default=EvalArgs.profile_number,
help="Number of samples to profile. If not set, will profile all samples.",
)
parser.add_argument(
"--concurrency",
type=int,
default=EvalArgs.concurrency,
help="Number of concurrent requests to make during evaluation. Default is 1, which means no concurrency.",
)
parser.add_argument(
"--lora-path",
type=str,
default=EvalArgs.lora_path,
help="Specify the LoRA path to use for evaluation. If specified, the value will be specified in the body of every request as `lora-path`.",
) )
parser.add_argument("--concurrency", type=int, default=EvalArgs.concurrency)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):