support vlm benchmark profile (#5905)
This commit is contained in:
@@ -10,8 +10,14 @@ The eval output will be logged
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import openai
|
import openai
|
||||||
from data_utils import save_json
|
from data_utils import save_json
|
||||||
from eval_utils import (
|
from eval_utils import (
|
||||||
@@ -25,8 +31,41 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse
|
from sglang.test.test_utils import add_common_sglang_args_and_parse
|
||||||
|
|
||||||
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
|
||||||
|
|
||||||
def eval_mmmu(args):
|
|
||||||
|
@dataclass
|
||||||
|
class RequestFuncOutput:
|
||||||
|
generated_text: List[str] = field(default_factory=list)
|
||||||
|
prompt_len: List[int] = field(default_factory=list)
|
||||||
|
output_len: List[int] = field(default_factory=list)
|
||||||
|
latency: List[float] = field(default_factory=list)
|
||||||
|
ttft: List[float] = field(default_factory=list)
|
||||||
|
itl: List[float] = field(default_factory=list) # List of inter-token latencies
|
||||||
|
|
||||||
|
success: bool = False
|
||||||
|
error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_profile(api_url: str) -> RequestFuncOutput:
|
||||||
|
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
try:
|
||||||
|
async with session.post(url=api_url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
output.success = True
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
|
except Exception:
|
||||||
|
output.success = False
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def eval_mmmu(args):
|
||||||
eval_args = EvalArgs.from_cli_args(args)
|
eval_args = EvalArgs.from_cli_args(args)
|
||||||
|
|
||||||
out_samples = dict()
|
out_samples = dict()
|
||||||
@@ -38,9 +77,22 @@ def eval_mmmu(args):
|
|||||||
answer_dict = {}
|
answer_dict = {}
|
||||||
|
|
||||||
# had to use an openai server, since SglImage doesn't support image data
|
# had to use an openai server, since SglImage doesn't support image data
|
||||||
client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
|
base_url = f"http://127.0.0.1:{args.port}"
|
||||||
|
client = openai.Client(api_key="sk", base_url=f"{base_url}/v1")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
|
if args.profile:
|
||||||
|
print("Starting profiler...")
|
||||||
|
profile_output = await async_request_profile(
|
||||||
|
api_url=f"{base_url}/start_profile"
|
||||||
|
)
|
||||||
|
if profile_output.success:
|
||||||
|
print("Profiler started")
|
||||||
|
|
||||||
|
if args.profile:
|
||||||
|
samples = samples[: args.profile_number]
|
||||||
|
|
||||||
for i, sample in enumerate(tqdm(samples)):
|
for i, sample in enumerate(tqdm(samples)):
|
||||||
prompt = sample["final_input_prompt"]
|
prompt = sample["final_input_prompt"]
|
||||||
prefix = prompt.split("<")[0]
|
prefix = prompt.split("<")[0]
|
||||||
@@ -49,6 +101,7 @@ def eval_mmmu(args):
|
|||||||
assert image is not None
|
assert image is not None
|
||||||
image_path = sample["image_path"]
|
image_path = sample["image_path"]
|
||||||
# TODO: batch
|
# TODO: batch
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
messages=[
|
messages=[
|
||||||
@@ -77,6 +130,12 @@ def eval_mmmu(args):
|
|||||||
response = response.choices[0].message.content
|
response = response.choices[0].message.content
|
||||||
process_result(response, sample, answer_dict, out_samples)
|
process_result(response, sample, answer_dict, out_samples)
|
||||||
|
|
||||||
|
if args.profile:
|
||||||
|
print("Stopping profiler...")
|
||||||
|
profile_output = await async_request_profile(api_url=f"{base_url}/stop_profile")
|
||||||
|
if profile_output.success:
|
||||||
|
print("Profiler stopped")
|
||||||
|
|
||||||
print(f"Benchmark time: {time.time() - start}")
|
print(f"Benchmark time: {time.time() - start}")
|
||||||
|
|
||||||
args.output_path = f"./val_sglang.json"
|
args.output_path = f"./val_sglang.json"
|
||||||
@@ -89,4 +148,4 @@ if __name__ == "__main__":
|
|||||||
EvalArgs.add_cli_args(parser)
|
EvalArgs.add_cli_args(parser)
|
||||||
args = add_common_sglang_args_and_parse(parser)
|
args = add_common_sglang_args_and_parse(parser)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
eval_mmmu(args)
|
asyncio.run(eval_mmmu(args))
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ class EvalArgs:
|
|||||||
prompt_format_file: str = "prompt_format.yaml"
|
prompt_format_file: str = "prompt_format.yaml"
|
||||||
dataset_path: str = "MMMU/MMMU"
|
dataset_path: str = "MMMU/MMMU"
|
||||||
extra_request_body: Optional[str] = None
|
extra_request_body: Optional[str] = None
|
||||||
|
profile: bool = False
|
||||||
|
profile_number: int = 5
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_cli_args(parser: argparse.ArgumentParser):
|
def add_cli_args(parser: argparse.ArgumentParser):
|
||||||
@@ -65,6 +67,12 @@ class EvalArgs:
|
|||||||
help="Append given JSON object to the request payload. You can use this to specify"
|
help="Append given JSON object to the request payload. You can use this to specify"
|
||||||
"additional generate params like sampling params.",
|
"additional generate params like sampling params.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile", action="store_true", help="enable mmmu profile"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile-number", type=int, default=EvalArgs.profile_number
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_cli_args(cls, args: argparse.Namespace):
|
def from_cli_args(cls, args: argparse.Namespace):
|
||||||
|
|||||||
Reference in New Issue
Block a user