Support LoRA in MMMU benchmark script. (#7218)
This commit is contained in:
@@ -15,7 +15,7 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
@@ -73,7 +73,7 @@ def _get_prefix_suffix(prompt: str) -> Tuple[str, str]:
|
||||
|
||||
|
||||
async def process_sample(
|
||||
client: Any, sample: dict, sampling_params: dict
|
||||
client: Any, sample: dict, sampling_params: dict, lora_path: Optional[str] = None
|
||||
) -> Tuple[dict, str]:
|
||||
"""Send a single sample to the LLM and return (sample, response)."""
|
||||
prompt = sample["final_input_prompt"]
|
||||
@@ -81,6 +81,7 @@ async def process_sample(
|
||||
image = sample["image"]
|
||||
assert image is not None
|
||||
image_path = sample["image_path"]
|
||||
extra_body = None if lora_path is None else {"lora_path": lora_path}
|
||||
response = await client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
@@ -96,16 +97,21 @@ async def process_sample(
|
||||
temperature=0,
|
||||
max_completion_tokens=sampling_params["max_new_tokens"],
|
||||
max_tokens=sampling_params["max_new_tokens"],
|
||||
extra_body=extra_body,
|
||||
)
|
||||
return sample, response.choices[0].message.content
|
||||
|
||||
|
||||
async def process_sample_with_semaphore(
|
||||
semaphore: asyncio.Semaphore, client: Any, sample: dict, sampling_params: dict
|
||||
semaphore: asyncio.Semaphore,
|
||||
client: Any,
|
||||
sample: dict,
|
||||
sampling_params: dict,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> Tuple[dict, str]:
|
||||
"""Wrap process_sample with a semaphore for concurrency control."""
|
||||
async with semaphore:
|
||||
return await process_sample(client, sample, sampling_params)
|
||||
return await process_sample(client, sample, sampling_params, lora_path)
|
||||
|
||||
|
||||
async def eval_mmmu(args) -> None:
|
||||
@@ -113,6 +119,7 @@ async def eval_mmmu(args) -> None:
|
||||
eval_args = EvalArgs.from_cli_args(args)
|
||||
sampling_params = get_sampling_params(eval_args)
|
||||
samples = prepare_samples(eval_args)
|
||||
lora_path = eval_args.lora_path
|
||||
answer_dict = {}
|
||||
out_samples = {}
|
||||
client = openai.AsyncOpenAI(
|
||||
@@ -133,7 +140,9 @@ async def eval_mmmu(args) -> None:
|
||||
samples = samples[: args.profile_number]
|
||||
|
||||
tasks = [
|
||||
process_sample_with_semaphore(semaphore, client, sample, sampling_params)
|
||||
process_sample_with_semaphore(
|
||||
semaphore, client, sample, sampling_params, lora_path
|
||||
)
|
||||
for sample in samples
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user