Support LoRA in MMMU benchmark script. (#7218)

This commit is contained in:
Lifu Huang
2025-06-15 21:17:57 -07:00
committed by GitHub
parent 3c2274fbee
commit e07d064729
3 changed files with 54 additions and 12 deletions

View File

@@ -15,7 +15,7 @@ import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple
import aiohttp
import openai
@@ -73,7 +73,7 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
async def process_sample(
client: Any, sample: dict, sampling_params: dict
client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
) -> Tuple[dict, str]:
"""Send a single sample to the LLM and return (sample, response)."""
prompt = sample["final_input_prompt"]
@@ -81,6 +81,7 @@ async def process_sample(
image = sample["image"]
assert image is not None
image_path = sample["image_path"]
extra_body = None if lora_path is None else {"lora_path": lora_path}
response = await client.chat.completions.create(
model="default",
messages=[
@@ -96,16 +97,21 @@ async def process_sample(
temperature=0,
max_completion_tokens=sampling_params["max_new_tokens"],
max_tokens=sampling_params["max_new_tokens"],
extra_body=extra_body,
)
return sample, response.choices[0].message.content
async def process_sample_with_semaphore(
    semaphore: asyncio.Semaphore,
    client: Any,
    sample: dict,
    sampling_params: dict,
    lora_path: Optional[str] = None,
) -> Tuple[dict, str]:
    """Wrap process_sample with a semaphore for concurrency control.

    The semaphore is acquired before ``process_sample`` is awaited and is
    released when that call finishes or raises, so the number of requests
    in flight is bounded by however many permits the caller gave the
    semaphore.

    Args:
        semaphore: Concurrency gate shared by all tasks created by the caller.
        client: Forwarded unchanged to ``process_sample``.
        sample: Forwarded unchanged to ``process_sample``.
        sampling_params: Forwarded unchanged to ``process_sample``.
        lora_path: Optional LoRA adapter identifier, forwarded unchanged to
            ``process_sample``. ``None`` (the default) keeps the call
            equivalent to the pre-LoRA four-argument form.

    Returns:
        Whatever ``process_sample`` returns, annotated as a
        ``(sample, response)`` tuple.
    """
    async with semaphore:
        # lora_path must be passed along here; dropping it would silently
        # evaluate the base model even when an adapter was requested.
        return await process_sample(client, sample, sampling_params, lora_path)
async def eval_mmmu(args) -> None:
@@ -113,6 +119,7 @@ async def eval_mmmu(args) -> None:
eval_args = EvalArgs.from_cli_args(args)
sampling_params = get_sampling_params(eval_args)
samples = prepare_samples(eval_args)
lora_path = eval_args.lora_path
answer_dict = {}
out_samples = {}
client = openai.AsyncOpenAI(
@@ -133,7 +140,9 @@ async def eval_mmmu(args) -> None:
samples = samples[: args.profile_number]
tasks = [
process_sample_with_semaphore(semaphore, client, sample, sampling_params)
process_sample_with_semaphore(
semaphore, client, sample, sampling_params, lora_path
)
for sample in samples
]