ci: refactor nightly test (#10495)
@@ -443,11 +443,9 @@ def latency_test_run_once(

    if profile:
        profiler.stop()
        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, profile_filename)
        rank_print(
            f"torch profiler chrome trace for prefill saved to {profile_filename}"
        )
        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
        _save_profile_trace_results(profiler, trace_filename)
        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")

    # Decode
    decode_latencies = []

@@ -479,10 +477,10 @@ def latency_test_run_once(

        if profile and i == output_len / 2:
            profiler.stop()
            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, profile_filename)
            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
            _save_profile_trace_results(profiler, trace_filename)
            rank_print(
                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
            )

        # Record decode timing from 2nd output
@@ -9,6 +9,7 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --

python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
"""

import argparse
@@ -19,12 +20,17 @@ import multiprocessing
import os
import random
import time
from typing import List, Tuple
from typing import List, Optional, Tuple

import numpy as np
import requests
from pydantic import BaseModel

from sglang.bench_serving import get_tokenizer, sample_random_requests
from sglang.bench_serving import (
    get_tokenizer,
    sample_mmmu_requests,
    sample_random_requests,
)
from sglang.profiler import run_profile
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
@@ -32,6 +38,109 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.test_utils import is_in_ci, write_github_step_summary


class ProfileLinks(BaseModel):
    """Pydantic model for profile trace links."""

    extend: Optional[str] = None
    decode: Optional[str] = None


class BenchmarkResult(BaseModel):
    """Pydantic model for benchmark results table data, for a single isl and osl"""

    model_path: str
    run_name: str
    batch_size: int
    input_len: int
    output_len: int
    latency: float
    ttft: float
    input_throughput: float
    output_throughput: float
    overall_throughput: float
    last_gen_throughput: float
    acc_length: Optional[float] = None
    profile_links: Optional[ProfileLinks] = None

    @staticmethod
    def help_str() -> str:
        return f"""
Note: To view the traces through perfetto-ui, please:
1. use Google Chrome
2. enable popup

"""

    def to_markdown_row(
        self, trace_dir, base_url: str = "", relay_base: str = ""
    ) -> str:
        """Convert this benchmark result to a markdown table row."""
        # Calculate costs (assuming H100 pricing for now)
        hourly_cost_per_gpu = 2  # $2/hour for one H100
        hourly_cost = hourly_cost_per_gpu * 1  # Assuming tp_size = 1 for simplicity
        input_util = 0.7
        accept_length = (
            round(self.acc_length, 2) if self.acc_length is not None else "n/a"
        )
        itl = 1 / (self.output_throughput / self.batch_size) * 1000
        input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
        output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost

        def get_perfetto_relay_link_from_trace_file(trace_file: str):
            import os
            from urllib.parse import quote

            rel_path = os.path.relpath(trace_file, trace_dir)
            raw_file_link = f"{base_url}/{rel_path}"
            relay_link = (
                f"{relay_base}?src={quote(raw_file_link, safe='')}"
                if relay_base and quote
                else raw_file_link
            )
            return relay_link

        # Handle profile links
        profile_link = "NA | NA"
        if self.profile_links:
            if self.profile_links.extend or self.profile_links.decode:
                # Create a combined link or use the first available one
                trace_files = [self.profile_links.extend, self.profile_links.decode]
                trace_files_relay_links = [
                    f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
                    for trace_file in trace_files
                ]

                profile_link = " | ".join(trace_files_relay_links)

        # Build the row
        return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"

    @classmethod
    def generate_markdown_report(
        cls, trace_dir, results: List["BenchmarkResult"]
    ) -> str:
        """Generate a markdown report from a list of BenchmarkResult object from a single run."""
        import os

        summary = f"### {results[0].model_path}\n"

        # summary += (
        #     f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
        # )
        summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
        summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"

        # all results should share the same isl & osl
        for result in results:
            base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
            relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
            relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
            # base_url = "https://github.com/sgl-project/ci-data/traces"
            summary += result.to_markdown_row(trace_dir, base_url, relay_base)

        return summary


@dataclasses.dataclass
class BenchArgs:
    run_name: str = "default"
@@ -50,8 +159,12 @@ class BenchArgs:
    profile: bool = False
    profile_steps: int = 3
    profile_by_stage: bool = False
    profile_filename_prefix: str = None
    append_to_github_summary: bool = True
    dataset_path: str = ""
    parallel_batch: bool = False
    dataset_name: str = "random"
    output_path: Optional[str] = None

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
@@ -67,6 +180,13 @@ class BenchArgs:
            "--output-len", type=int, nargs="+", default=BenchArgs.output_len
        )
        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
        parser.add_argument(
            "--dataset-name",
            type=str,
            default=BenchArgs.dataset_name,
            choices=["mmmu", "random"],
            help="Name of the dataset to benchmark on.",
        )
        parser.add_argument("--return-logprob", action="store_true")
        parser.add_argument(
            "--client-stream-interval",
@@ -96,14 +216,36 @@ class BenchArgs:
            help="Path to the dataset.",
        )
        parser.add_argument("--parallel-batch", action="store_true")
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgs.profile_filename_prefix,
        )
        parser.add_argument(
            "--no-append-to-github-summary",
            action="store_false",
            dest="append_to_github_summary",
            help="Disable appending the output of this run to github ci summary",
        )
        parser.add_argument(
            "--output-path",
            type=str,
            default=BenchArgs.output_path,
            help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # use the default value's type to cast the args into correct types.
        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
        return cls(
            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
        )
        kwargs = {}
        for attr, attr_type in attrs:
            val = getattr(args, attr)
            if attr_type is type(None):
                kwargs[attr] = val
            else:
                kwargs[attr] = attr_type(val)
        return cls(**kwargs)


def launch_server_internal(server_args):
@@ -148,23 +290,35 @@ def run_one_case(
    run_name: str,
    result_filename: str,
    tokenizer,
    dataset_name="",
    profile: bool = False,
    profile_steps: int = 3,
    profile_by_stage: bool = False,
    profile_filename_prefix: str = None,
    dataset_path: str = "",
    parallel_batch: bool = False,
):
    requests.post(url + "/flush_cache")
    input_requests = sample_random_requests(
        input_len=input_len,
        output_len=output_len,
        num_prompts=batch_size,
        range_ratio=1.0,
        tokenizer=tokenizer,
        dataset_path=dataset_path,
        random_sample=True,
        return_text=False,
    )
    # TODO: reuse bench_serving.get_dataset ?
    if dataset_name == "mmmu":
        input_requests = sample_mmmu_requests(
            num_requests=batch_size,
            tokenizer=tokenizer,
            fixed_output_len=output_len,
            apply_chat_template=True,
            random_sample=False,
        )
    elif dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=input_len,
            output_len=output_len,
            num_prompts=batch_size,
            range_ratio=1.0,
            tokenizer=tokenizer,
            dataset_path=dataset_path,
            random_sample=True,
            return_text=False,
        )

    use_structured_outputs = False
    if use_structured_outputs:
@@ -181,26 +335,48 @@ def run_one_case(

    profile_link = None
    if profile:
        output_dir, profile_name = None, None
        if profile_filename_prefix:
            output_dir = os.path.dirname(profile_filename_prefix)
            profile_name = os.path.basename(profile_filename_prefix)
        profile_link: str = run_profile(
            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
            url,
            profile_steps,
            ["CPU", "GPU"],
            output_dir,
            profile_name,
            profile_by_stage,
        )

    tic = time.perf_counter()

    payload = {
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": output_len,
            "ignore_eos": True,
            "json_schema": json_schema,
            "stream_interval": stream_interval,
        },
        "return_logprob": return_logprob,
        "stream": True,
        **({"parallel_batch": parallel_batch} if parallel_batch else {}),
    }
    if dataset_name == "mmmu":
        # vlm
        input_ids = []
        for input_req in input_requests:
            input_ids += [tokenizer.encode(input_req.prompt)]
        payload["image_data"] = [req.image_data for req in input_requests]

    else:
        input_ids = [req.prompt for req in input_requests]

    payload["input_ids"] = input_ids

    response = requests.post(
        url + "/generate",
        json={
            "input_ids": [req.prompt for req in input_requests],
            "sampling_params": {
                "temperature": temperature,
                "max_new_tokens": output_len,
                "ignore_eos": True,
                "json_schema": json_schema,
                "stream_interval": stream_interval,
            },
            "return_logprob": return_logprob,
            "stream": True,
            **({"parallel_batch": parallel_batch} if parallel_batch else {}),
        },
        json=payload,
        stream=True,
    )
@@ -264,10 +440,100 @@ def run_one_case(
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link if profile else None,
        profile_link,
    )


def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
    """Save benchmark results as JSON using Pydantic models."""
    json_results = []

    # Generate all parameter combinations to match with results
    param_combinations = list(
        itertools.product(
            bench_args.batch_size, bench_args.input_len, bench_args.output_len
        )
    )

    for i, (
        batch_size,
        latency,
        ttft,
        input_throughput,
        output_throughput,
        overall_throughput,
        last_gen_throughput,
        acc_length,
        profile_link,
    ) in enumerate(result):
        # Get the corresponding parameters for this result
        bs, input_len, output_len = param_combinations[i]

        # Parse profile links if available
        profile_links = None
        if profile_link:
            profile_links = parse_profile_links(
                profile_link, batch_size, input_len, output_len
            )

        benchmark_result = BenchmarkResult(
            model_path=model,
            run_name=bench_args.run_name,
            batch_size=batch_size,
            input_len=input_len,
            output_len=output_len,
            latency=latency,
            ttft=ttft,
            input_throughput=input_throughput,
            output_throughput=output_throughput,
            overall_throughput=overall_throughput,
            last_gen_throughput=last_gen_throughput,
            acc_length=acc_length,
            profile_links=profile_links,
        )
        json_results.append(benchmark_result.model_dump())

    # Save to JSON file
    with open(bench_args.output_path, "w", encoding="utf-8") as f:
        json.dump(json_results, f, indent=2, ensure_ascii=False)

    print(f"Results saved as JSON to {bench_args.output_path}")


def parse_profile_links(
    profile_dir: str, batch_size: int, input_len: int, output_len: int
) -> Optional[ProfileLinks]:
    """Parse profile directory to extract extend and decode trace file links."""
    if not profile_dir or not os.path.exists(profile_dir):
        return None

    extend_link = None
    decode_link = None

    # Look for extend/prefill trace files
    for file in os.listdir(profile_dir):
        if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
            if "extend" in file.lower() or "prefill" in file.lower():
                extend_link = os.path.join(profile_dir, file)
            elif "decode" in file.lower():
                decode_link = os.path.join(profile_dir, file)

    # If no specific extend/decode files found, try to find files with batch/input/output info
    if not extend_link or not decode_link:
        for file in os.listdir(profile_dir):
            if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
                if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
                    if "prefill" in file.lower() or "extend" in file.lower():
                        extend_link = os.path.join(profile_dir, file)
                    elif "decode" in file.lower():
                        decode_link = os.path.join(profile_dir, file)

    if extend_link or decode_link:
        return ProfileLinks(extend=extend_link, decode=decode_link)

    return None


def get_report_summary(
    result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
):
@@ -358,6 +624,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
            return_logprob=bench_args.return_logprob,
            stream_interval=bench_args.client_stream_interval,
            input_len_step_percentage=bench_args.input_len_step_percentage,
            dataset_name=bench_args.dataset_name,
            run_name="",
            result_filename="",
            tokenizer=tokenizer,
@@ -384,10 +651,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                stream_interval=bench_args.client_stream_interval,
                input_len_step_percentage=bench_args.input_len_step_percentage,
                run_name=bench_args.run_name,
                dataset_name=bench_args.dataset_name,
                result_filename=bench_args.result_filename,
                tokenizer=tokenizer,
                dataset_path=bench_args.dataset_path,
                parallel_batch=bench_args.parallel_batch,
                profile_filename_prefix=bench_args.profile_filename_prefix,
            )
        )
@@ -410,11 +679,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                    run_name=bench_args.run_name,
                    result_filename=bench_args.result_filename,
                    tokenizer=tokenizer,
                    dataset_name=bench_args.dataset_name,
                    profile=bench_args.profile,
                    profile_steps=bench_args.profile_steps,
                    profile_by_stage=bench_args.profile_by_stage,
                    dataset_path=bench_args.dataset_path,
                    parallel_batch=bench_args.parallel_batch,
                    profile_filename_prefix=bench_args.profile_filename_prefix,
                )[-1],
            )
        )
@@ -427,13 +698,16 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):

    print(f"\nResults are saved to {bench_args.result_filename}")

    # Save results as JSON if output_path is specified
    if bench_args.output_path:
        save_results_as_json(result, bench_args, model=server_args.model_path)

    if not bench_args.show_report:
        return

    summary = get_report_summary(result, server_args, bench_args)
    print(summary)

    if is_in_ci():
    if is_in_ci() and bench_args.append_to_github_summary:
        write_github_step_summary(summary)
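
Editor's note: the hunks above add an --output-path JSON dump and the Pydantic result models. As a minimal sketch of how the two pieces could be combined after a run (not part of this commit; the file name "results.json" and the trace directory are illustrative), the saved JSON can be loaded back into BenchmarkResult objects and rendered with generate_markdown_report:

# Hypothetical post-processing sketch, assuming this commit's bench_one_batch_server module.
import json

from sglang.bench_one_batch_server import BenchmarkResult

with open("results.json", "r", encoding="utf-8") as f:
    rows = json.load(f)  # list of dicts written by save_results_as_json

results = [BenchmarkResult(**row) for row in rows]  # pydantic re-validation
report = BenchmarkResult.generate_markdown_report("/tmp/traces", results)  # trace dir is made up
print(report)  # markdown table: one row per (batch_size, input_len, output_len) case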
@@ -208,6 +208,10 @@ async def async_request_openai_completions(
        "ignore_eos": not args.disable_ignore_eos,
        **request_func_input.extra_request_body,
    }

    if request_func_input.image_data:
        payload.update({"image_data": request_func_input.image_data})

    headers = get_auth_headers()

    output = RequestFuncOutput.init_new(request_func_input)
@@ -664,7 +668,7 @@ def get_dataset(args, tokenizer):
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
            dataset_path=args.dataset_name,
            random_sample=args.dataset_name == "random",
            return_text=not tokenize_prompt,
        )
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        stage_str = f" for {stage.__str__()}" if stage else ""
        stage_str = f" for {stage.name}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{stage.__str__()}" if stage else ""
        stage_suffix = f"-{stage.name}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
        if self.profiler_decode_ct == 0:
            if self.profile_in_progress:
                # force trace flush
                self.stop_profile(ForwardMode.EXTEND)
                self.stop_profile(stage=ForwardMode.EXTEND)
            self.start_profile(batch.forward_mode)
        self.profiler_decode_ct += 1
        if self.profiler_decode_ct > self.profiler_target_decode_ct:
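
Editor's note: the `stage.__str__()` to `stage.name` change swaps the full enum representation for the bare member name in log lines and trace-file suffixes. A hedged illustration with a plain enum.Enum stand-in (sglang's real ForwardMode is defined elsewhere and may format slightly differently):

# Stand-in enum, only to show the .name vs str() difference; not sglang's actual class.
import enum

class ForwardMode(enum.Enum):
    EXTEND = enum.auto()
    DECODE = enum.auto()

stage = ForwardMode.EXTEND
print(str(stage))   # "ForwardMode.EXTEND" -> suffix "-ForwardMode.EXTEND"
print(stage.name)   # "EXTEND"             -> suffix "-EXTEND"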
@@ -60,6 +60,11 @@ def run_eval(args):
        from sglang.test.simple_eval_humaneval import HumanEval

        eval_obj = HumanEval(args.num_examples, args.num_threads)
    elif args.eval_name == "mmmu":
        # VLM MMMU evaluation with fixed 100 examples by default
        from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval

        eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
    else:
        raise ValueError(f"Invalid eval name: {args.eval_name}")

@@ -94,6 +99,8 @@ def run_eval(args):
    print(f"Total latency: {latency:.3f} s")
    print(f"Score: {metrics['score']:.3f}")

    if getattr(args, "return_latency", False):
        return metrics, latency
    return metrics
python/sglang/test/simple_eval_mmmu_vlm.py (new file, 441 lines)
@@ -0,0 +1,441 @@
"""
|
||||
MMMU evaluation for VLMs using the run_eval simple-evals interface.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from PIL import Image
|
||||
|
||||
from sglang.test import simple_eval_common as common
|
||||
from sglang.test.simple_eval_common import (
|
||||
HTML_JINJA,
|
||||
Eval,
|
||||
EvalResult,
|
||||
SamplerBase,
|
||||
SingleEvalResult,
|
||||
map_with_progress,
|
||||
)
|
||||
|
||||
|
||||
class MMMUVLMEval(Eval):
|
||||
DOMAIN_CAT2SUB_CAT = {
|
||||
"Art and Design": ["Art", "Art_Theory", "Design", "Music"],
|
||||
"Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
|
||||
"Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
|
||||
"Health and Medicine": [
|
||||
"Basic_Medical_Science",
|
||||
"Clinical_Medicine",
|
||||
"Diagnostics_and_Laboratory_Medicine",
|
||||
"Pharmacy",
|
||||
"Public_Health",
|
||||
],
|
||||
"Humanities and Social Science": [
|
||||
"History",
|
||||
"Literature",
|
||||
"Sociology",
|
||||
"Psychology",
|
||||
],
|
||||
"Tech and Engineering": [
|
||||
"Agriculture",
|
||||
"Architecture_and_Engineering",
|
||||
"Computer_Science",
|
||||
"Electronics",
|
||||
"Energy_and_Power",
|
||||
"Materials",
|
||||
"Mechanical_Engineering",
|
||||
],
|
||||
}
|
||||
|
||||
    def __init__(
        self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
    ):
        """Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
        self.num_examples = num_examples
        self.num_threads = num_threads
        self.seed = seed
        # Prepare samples deterministically across all MMMU subjects (validation split)
        self.samples = self._prepare_mmmu_samples(self.num_examples)

    @staticmethod
    def _to_data_uri(image: Image.Image) -> str:
        if image.mode == "RGBA":
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{b64}"

    @staticmethod
    def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
        index2ans = {}
        all_choices = []
        ch = ord("A")
        for opt in options:
            letter = chr(ch)
            index2ans[letter] = opt
            all_choices.append(letter)
            ch += 1
        return index2ans, all_choices

    def _prepare_mmmu_samples(self, k: int) -> List[dict]:
        # Subjects and domains copied from MMMU data_utils to categorize results
        subjects: List[str] = []
        for subs in self.DOMAIN_CAT2SUB_CAT.values():
            subjects.extend(subs)

        # Load validation split of each subject
        datasets = []
        for subj in subjects:
            try:
                d = load_dataset("MMMU/MMMU", subj, split="validation")
                # attach subject info via transform
                d = d.add_column("__subject__", [subj] * len(d))
                datasets.append(d)
            except Exception:
                continue
        if not datasets:
            raise RuntimeError("Failed to load MMMU datasets")

        merged = concatenate_datasets(datasets)

        # Deterministic selection: sort by id (fallback to subject+index)
        def _key(idx):
            ex = merged[idx]
            return str(ex.get("id", f"{ex['__subject__']}:{idx}"))

        order = sorted(range(len(merged)), key=_key)
        picked_indices = order[:k]

        samples: List[dict] = []
        for idx in picked_indices:
            ex = merged[idx]
            subject = ex["__subject__"]
            image = ex.get("image_1")
            if image is None or not hasattr(image, "convert"):
                continue
            data_uri = self._to_data_uri(image)
            question = ex.get("question", "")
            answer = ex.get("answer")
            raw_options = ex.get("options")
            question_type = "open"
            index2ans = None
            all_choices = None
            options = None
            if raw_options:
                try:
                    options = (
                        raw_options
                        if isinstance(raw_options, list)
                        else list(eval(raw_options))
                    )
                    if isinstance(options, list) and len(options) > 0:
                        index2ans, all_choices = self._build_mc_mapping(options)
                        question_type = "multiple-choice"
                except Exception:
                    options = None

            # Build final textual prompt; include choices if MC
            prompt_text = f"Question: {question}\n\n"
            if options:
                letters = [chr(ord("A") + i) for i in range(len(options))]
                for letter, opt in zip(letters, options):
                    prompt_text += f"{letter}) {opt}\n"
            prompt_text += "\nAnswer: "

            samples.append(
                {
                    "id": ex.get("id", f"{subject}:{idx}"),
                    "final_input_prompt": prompt_text,
                    "image_data": data_uri,
                    "answer": answer,
                    "question_type": question_type,
                    "index2ans": index2ans,
                    "all_choices": all_choices,
                    "category": subject,
                }
            )

        return samples
    @staticmethod
    def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
        """Split a prompt containing an inline image tag into prefix and suffix.

        If no tag is present, treat the whole prompt as prefix and empty suffix.
        """
        if "<" in prompt and ">" in prompt:
            prefix = prompt.split("<")[0]
            suffix = prompt.split(">", 1)[1]
            return prefix, suffix
        return prompt, ""

    @staticmethod
    def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
        """Split a prompt containing an inline image tag into prefix and suffix.

        If no tag is present, treat the whole prompt as prefix and empty suffix.
        """
        # Build a vision+text message for OpenAI-compatible API
        prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)

        content: List[dict] = []
        if prefix:
            content.append({"type": "text", "text": prefix})
        content.append({"type": "image_url", "image_url": {"url": image_data}})
        if suffix:
            content.append({"type": "text", "text": suffix})
        prompt_messages = [{"role": "user", "content": content}]

        return prompt_messages

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        def fn(sample: dict):
            prompt = sample["final_input_prompt"]
            image_data = sample["image_data"]
            prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
                prompt, image_data
            )

            # Sample
            response_text = sampler(prompt_messages)

            # Parse and score
            gold = sample["answer"]
            if (
                sample["question_type"] == "multiple-choice"
                and sample["all_choices"]
                and sample["index2ans"]
            ):
                pred = _parse_multi_choice_response(
                    response_text, sample["all_choices"], sample["index2ans"]
                )
                score = 1.0 if (gold is not None and pred == gold) else 0.0
                extracted_answer = pred
            else:
                parsed_list = _parse_open_response(response_text)
                score = (
                    1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
                )
                extracted_answer = ", ".join(map(str, parsed_list))

            html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=gold,
                extracted_answer=extracted_answer,
            )

            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html_rendered,
                score=score,
                metrics={"__category__": sample["category"]},
                convo=convo,
            )

        results = map_with_progress(fn, self.samples, self.num_threads)

        # Build category table and overall accuracy
        # Gather per-sample correctness and category
        per_cat_total: dict[str, int] = {}
        per_cat_correct: dict[str, int] = {}
        htmls = []
        convos = []
        scores: List[float] = []
        for r in results:
            # __category__ stored under metrics
            cat = r.metrics.get("__category__") if r.metrics else None
            if cat is None:
                cat = "Unknown"
            per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
            if r.score:
                per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
            htmls.append(r.html)
            convos.append(r.convo)
            if r.score is not None:
                scores.append(r.score)

        evaluation_result = {}
        for cat, tot in per_cat_total.items():
            corr = per_cat_correct.get(cat, 0)
            acc = (corr / tot) if tot > 0 else 0.0
            evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}

        printable_results = {}
        # Domains first
        for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
            acc_sum = 0.0
            num_sum = 0
            for cat in cats:
                if cat in evaluation_result:
                    acc_sum += (
                        evaluation_result[cat]["acc"]
                        * evaluation_result[cat]["num_example"]
                    )
                    num_sum += evaluation_result[cat]["num_example"]
            if num_sum > 0:
                printable_results[f"Overall-{domain}"] = {
                    "num": num_sum,
                    "acc": round(acc_sum / num_sum, 3),
                }
            # add each sub-category row if present
            for cat in cats:
                if cat in evaluation_result:
                    printable_results[cat] = {
                        "num": evaluation_result[cat]["num_example"],
                        "acc": evaluation_result[cat]["acc"],
                    }

        # Overall
        total_num = sum(v["num_example"] for v in evaluation_result.values())
        overall_acc = (
            sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
            / total_num
            if total_num > 0
            else 0.0
        )
        printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}

        # Build EvalResult
        return EvalResult(
            score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
        )
def _parse_multi_choice_response(
    response: str, all_choices: List[str], index2ans: dict
) -> str:
    # loosely adapted from benchmark mmmu eval
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "

    # Prefer explicit letter with bracket e.g. (A)
    candidates: List[str] = []
    for choice in all_choices:
        if f"({choice})" in response:
            candidates.append(choice)
    if not candidates:
        for choice in all_choices:
            if f" {choice} " in response:
                candidates.append(choice)
    if not candidates and len(response.split()) > 5:
        # try match by option text
        for idx, ans in index2ans.items():
            if ans and ans.lower() in response.lower():
                candidates.append(idx)
    if not candidates:
        # fallback to first choice
        return all_choices[0]
    if len(candidates) == 1:
        return candidates[0]
    # choose the last occurrence
    starts = []
    for can in candidates:
        pos = response.rfind(f"({can})")
        if pos == -1:
            pos = response.rfind(f" {can} ")
        if pos == -1 and index2ans.get(can):
            pos = response.lower().rfind(index2ans[can].lower())
        starts.append(pos)
    return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]


def _check_is_number(s: str) -> bool:
    try:
        float(s.replace(",", ""))
        return True
    except Exception:
        return False


def _normalize_str(s: str):
    s = s.strip()
    if _check_is_number(s):
        s = s.replace(",", "")
        try:
            v = round(float(s), 2)
            return [v]
        except Exception:
            return [s.lower()]
    return [s.lower()] if len(s) > 1 else [" " + s, s + " "]


def _extract_numbers(s: str) -> List[str]:
    import re as _re

    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
    return (
        _re.findall(pattern_commas, s)
        + _re.findall(pattern_scientific, s)
        + _re.findall(pattern_simple, s)
    )


def _parse_open_response(response: str) -> List[str]:
    import re as _re

    def get_key_subresponses(resp: str) -> List[str]:
        resp = resp.strip().strip(".").lower()
        subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
        indicators = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        keys = []
        for i, s in enumerate(subs):
            cands = [*indicators]
            if i == len(subs) - 1:
                cands.append("=")
            shortest = None
            for ind in cands:
                if ind in s:
                    part = s.split(ind)[-1].strip()
                    if not shortest or len(part) < len(shortest):
                        shortest = part
            if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                keys.append(shortest)
        return keys or [resp]

    key_resps = get_key_subresponses(response)
    pred_list = key_resps.copy()
    for r in key_resps:
        pred_list.extend(_extract_numbers(r))
    out = []
    for x in pred_list:
        out.extend(_normalize_str(x))
    # dedup
    return list(dict.fromkeys(out))


def _eval_open(gold, preds: List[str]) -> bool:
    if isinstance(gold, list):
        norm_answers = []
        for ans in gold:
            norm_answers.extend(_normalize_str(ans))
    else:
        norm_answers = _normalize_str(gold)
    for p in preds:
        if isinstance(p, str):
            for na in norm_answers:
                if isinstance(na, str) and na in p:
                    return True
        else:
            if p in norm_answers:
                return True
    return False
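
Editor's note: for orientation, a hedged example of how the multiple-choice parser above behaves on a typical model reply; the option texts and the reply are invented:

# Hypothetical call illustrating _parse_multi_choice_response from the new module.
index2ans = {"A": "4", "B": "6", "C": "8", "D": "12"}
all_choices = ["A", "B", "C", "D"]

pred = _parse_multi_choice_response("The correct answer is (B).", all_choices, index2ans)
print(pred)  # "B": the bracketed "(B)" pattern is matched before any free-standing letter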
@@ -14,10 +14,12 @@ import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from urllib.parse import quote

import aiohttp
import numpy as np
@@ -1467,3 +1469,137 @@ def dump_bench_raw_result(
def _ensure_remove_suffix(text: str, suffix: str):
    assert text.endswith(suffix)
    return text.removesuffix(suffix)


class ModelDeploySetup:
    def __init__(self, model_path: str, extra_args: List[str] = []):
        self.model_path = model_path
        if "--enable-multimodal" not in extra_args:
            extra_args.append("--enable-multimodal")
        if "--trust-remote-code" not in extra_args:
            extra_args.append("--trust-remote-code")

        self.extra_args = extra_args


class ModelEvalMetrics:
    def __init__(self, accuracy: float, eval_time: float):
        self.accuracy = accuracy
        self.eval_time = eval_time


def extract_trace_link_from_bench_one_batch_server_output(output: str) -> str:
    match = re.search(r"\[Profile\]\((.*?)\)", output)
    if match:
        trace_link = match.group(1)
        return trace_link
    return None


def parse_models(model_string: str):
    return [model.strip() for model in model_string.split(",") if model.strip()]
def check_evaluation_test_results(
    results,
    test_name,
    model_accuracy_thresholds,
    model_latency_thresholds=None,
    model_count=None,
):
    """
    results: list of tuple of (model_path, accuracy, latency)
    """
    failed_models = []
    if model_latency_thresholds is not None:
        summary = " | model | status | score | score_threshold | latency | latency_threshold | \n"
        summary += "| ----- | ------ | ----- | --------------- | ------- | ----------------- | \n"
    else:
        summary = " | model | status | score | score_threshold | \n"
        summary += "| ----- | ------ | ----- | --------------- | \n"

    for model, accuracy, latency in results:
        accuracy_threshold = model_accuracy_thresholds.get(model)
        if accuracy_threshold is None:
            print(f"Warning: No threshold defined for model {model}")
            continue

        latency_threshold = (
            model_latency_thresholds.get(model, None)
            if model_latency_thresholds
            else 1e9
        )

        is_success = accuracy >= accuracy_threshold and latency <= latency_threshold
        status_emoji = "✅" if is_success else "❌"

        if not is_success:
            failed_models.append(
                f"\nScore Check Failed: {model}\n"
                f"Model {model} score ({accuracy:.4f}) is below threshold ({accuracy_threshold:.4f})"
            )

        if model_latency_thresholds is not None:
            line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | {latency} | {latency_threshold}\n"
        else:
            line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold}\n"

        summary += line

    print(summary)

    if is_in_ci():
        write_github_step_summary(f"## {test_name}\n{summary}")

    some_model_failed_to_get_result = len(results) != (
        model_count or len(model_accuracy_thresholds)
    )
    if some_model_failed_to_get_result:
        print("Some model has failed to launch and be evaluated")

    if failed_models or some_model_failed_to_get_result:
        raise AssertionError("\n".join(failed_models))


# Bench knobs for bench_one_batch_server (override by env)
def _parse_int_list_env(name: str, default_val: str):
    val = os.environ.get(name, default_val)
    return [int(x) for x in val.split(",") if x]


# Return filenames
def find_traces_under_path(path: str) -> List[str]:
    results = []
    for _, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".trace.json.gz"):
                results.append(f"{file}")
    return results


def write_results_to_json(model, metrics, mode="a"):
    result = {
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "metrics": metrics,
        "score": metrics["score"],
    }

    if "latency" in metrics:
        result["latency"] = (metrics.get("latency"),)

    existing_results = []
    if mode == "a" and os.path.exists("results.json"):
        try:
            with open("results.json", "r") as f:
                existing_results = json.load(f)
        except json.JSONDecodeError:
            existing_results = []

    if isinstance(existing_results, list):
        existing_results.append(result)
    else:
        existing_results = [result]

    with open("results.json", "w") as f:
        json.dump(existing_results, f, indent=2)
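
Editor's note: a hedged usage sketch of the new result-checking helper above; the test name, thresholds, and latency value are invented, and the model path is reused from the benchmark usage string only as an example:

# Hypothetical values; check_evaluation_test_results raises AssertionError when any model fails.
results = [
    ("meta-llama/Meta-Llama-3.1-8B", 0.83, 95.0),  # (model_path, accuracy, latency)
]
check_evaluation_test_results(
    results,
    test_name="nightly-eval-example",
    model_accuracy_thresholds={"meta-llama/Meta-Llama-3.1-8B": 0.80},
    model_latency_thresholds={"meta-llama/Meta-Llama-3.1-8B": 120.0},
    model_count=1,
)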