support vlm benchmark profile (#5905)
This commit is contained in:
@@ -10,8 +10,14 @@ The eval output will be logged
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
from data_utils import save_json
|
||||
from eval_utils import (
|
||||
@@ -25,8 +31,41 @@ from tqdm import tqdm
|
||||
|
||||
from sglang.test.test_utils import add_common_sglang_args_and_parse
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
|
||||
|
||||
def eval_mmmu(args):
|
||||
|
||||
@dataclass
class RequestFuncOutput:
    """Outcome of an HTTP request made during the benchmark.

    NOTE(review): within this view only `success` and `error` are written
    (by async_request_profile); the list fields are presumably populated by
    other request helpers that accumulate per-request measurements — confirm
    against the rest of the file.
    """

    generated_text: List[str] = field(default_factory=list)
    prompt_len: List[int] = field(default_factory=list)
    output_len: List[int] = field(default_factory=list)
    latency: List[float] = field(default_factory=list)
    # ttft: time to first token; itl: inter-token latencies.
    ttft: List[float] = field(default_factory=list)
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies

    # Whether the HTTP call itself completed with a 200 status, and the
    # error text (reason phrase or formatted traceback) when it did not.
    success: bool = False
    error: str = ""
|
||||
|
||||
|
||||
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST to *api_url* and report whether the call succeeded.

    Used to hit the server's ``/start_profile`` and ``/stop_profile``
    endpoints. Only ``success`` and ``error`` on the returned
    RequestFuncOutput are populated; the measurement lists stay empty.

    Args:
        api_url: Full URL of the profiling endpoint.

    Returns:
        RequestFuncOutput with ``success=True`` on HTTP 200; otherwise
        ``success=False`` and ``error`` holding the HTTP reason phrase or
        the formatted exception traceback.
    """
    result = RequestFuncOutput()
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        try:
            async with session.post(url=api_url) as resp:
                result.success = resp.status == 200
                if not result.success:
                    result.error = resp.reason or ""
        except Exception:
            # Keep the full traceback so the caller can see why the
            # profiling request failed.
            result.success = False
            result.error = traceback.format_exc()
    return result
|
||||
|
||||
|
||||
async def eval_mmmu(args):
|
||||
eval_args = EvalArgs.from_cli_args(args)
|
||||
|
||||
out_samples = dict()
|
||||
@@ -38,9 +77,22 @@ def eval_mmmu(args):
|
||||
answer_dict = {}
|
||||
|
||||
# had to use an openai server, since SglImage doesn't support image data
|
||||
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
|
||||
base_url = f"http://127.0.0.1:{args.port}"
|
||||
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
|
||||
|
||||
start = time.time()
|
||||
|
||||
if args.profile:
|
||||
print("Starting profiler...")
|
||||
profile_output = await async_request_profile(
|
||||
api_url=f"{base_url}/start_profile"
|
||||
)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
if args.profile:
|
||||
samples = samples[: args.profile_number]
|
||||
|
||||
for i, sample in enumerate(tqdm(samples)):
|
||||
prompt = sample["final_input_prompt"]
|
||||
prefix = prompt.split("<")[0]
|
||||
@@ -49,6 +101,7 @@ def eval_mmmu(args):
|
||||
assert image is not None
|
||||
image_path = sample["image_path"]
|
||||
# TODO: batch
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
@@ -77,6 +130,12 @@ def eval_mmmu(args):
|
||||
response = response.choices[0].message.content
|
||||
process_result(response, sample, answer_dict, out_samples)
|
||||
|
||||
if args.profile:
|
||||
print("Stopping profiler...")
|
||||
profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
|
||||
if profile_output.success:
|
||||
print("Profiler stopped")
|
||||
|
||||
print(f"Benchmark time: {time.time() - start}")
|
||||
|
||||
args.output_path = f"./val_sglang.json"
|
||||
@@ -89,4 +148,4 @@ if __name__ == "__main__":
|
||||
EvalArgs.add_cli_args(parser)
|
||||
args = add_common_sglang_args_and_parse(parser)
|
||||
args = parser.parse_args()
|
||||
eval_mmmu(args)
|
||||
asyncio.run(eval_mmmu(args))
|
||||
|
||||
@@ -33,6 +33,8 @@ class EvalArgs:
|
||||
prompt_format_file: str = "prompt_format.yaml"
|
||||
dataset_path: str = "MMMU/MMMU"
|
||||
extra_request_body: Optional[str] = None
|
||||
profile: bool = False
|
||||
profile_number: int = 5
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -65,6 +67,12 @@ class EvalArgs:
|
||||
help="Append given JSON object to the request payload. You can use this to specify"
|
||||
"additional generate params like sampling params.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile", action="store_true", help="enable mmmu profile"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-number", type=int, default=EvalArgs.profile_number
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
Reference in New Issue
Block a user