Files
enginex-vastai-va16-vllm/torch_vacc/testing/summarize_report.py
2026-04-02 04:55:00 +00:00

104 lines
2.5 KiB
Python

"""
Tool to summarize unit test XML reports, it summarize
* number of tests, and failure/error/skipped
* top 10 slowest tests
Usage:
python -m torch_vacc.testing.summarize_report --report report.xml
"""
import argparse
from dataclasses import dataclass
from xml.etree import ElementTree as ET
import sys
from torch_vacc import set_global_log_level
def parse_args():
    """Parse the command-line arguments of the report summarizer.

    Returns:
        argparse.Namespace whose ``report`` attribute is the path of the XML
        report to summarize.
    """
    parser = argparse.ArgumentParser(
        description="Summarize a unit test XML report."
    )
    # Without --report there is nothing to read; rejecting the command line
    # here yields a usage message instead of a TypeError when the (None)
    # path is later opened.
    parser.add_argument(
        "--report",
        type=str,
        required=True,
        help="path to the XML test report to summarize",
    )
    return parser.parse_args()
def summarize_testsuites(suites):
    """Aggregate the counters of one or more ``<testsuite>`` elements.

    Args:
        suites: iterable of elements whose ``attrib`` mapping holds the
            suite-level counters (``tests``, ``errors``, ``failures``,
            ``skipped`` or ``skips``, ``time``) as strings.

    Returns:
        dict with the summed ``errors``, ``failures``, ``skipped``, ``skips``
        and ``tests`` counts (ints) and the summed ``time`` (float, or the
        int 0 when no suite reports a time). Missing attributes count as
        zero. ``skipped`` also counts suites that only report ``skips``.
    """
    # How each counter is parsed from its attribute string.
    converters = {
        "errors": int,
        "failures": int,
        "skipped": int,
        "skips": int,
        "tests": int,
        "time": float,
    }
    # Materialized because the attributes are walked more than once below.
    attribs = [s.attrib for s in suites]
    summary = {
        key: sum(convert(a[key]) for a in attribs if key in a)
        for key, convert in converters.items()
    }
    # Report writers disagree on the name of the skip counter. Resolve it
    # per suite (prefer "skipped", fall back to "skips") so reports mixing
    # both spellings are totalled correctly instead of tripping an assert
    # (or, under ``python -O``, silently losing the "skipped" count).
    summary["skipped"] = sum(
        int(a["skipped"]) if "skipped" in a else int(a.get("skips", 0))
        for a in attribs
    )
    return summary
def format_summary(summary):
    """Render a test summary as a single human-readable line.

    Args:
        summary: mapping providing at least ``tests``, ``time``, ``errors``,
            ``failures`` and ``skipped`` (as produced by
            ``summarize_testsuites``); extra keys are ignored.

    Returns:
        The formatted line, prefixed with ``"FAILED. "`` whenever any error
        or failure was recorded.
    """
    line = (
        "Ran {tests} tests in {time:.3f}s "
        "(errors={errors}, failures={failures}, skipped={skipped})"
    ).format(**summary)
    broken = summary["errors"] > 0 or summary["failures"] > 0
    return f"FAILED. {line}" if broken else line
@dataclass
class TestCaseInfo:
    """Timing and outcome of a single ``<testcase>`` from a report.

    Instances order by ``time``, so a list of them can be sorted by duration.
    """

    test_class_name: str
    test_name: str
    time: float
    # Empty string when the testcase element has no ``timestamp`` attribute.
    timestamp: str
    # False when the testcase has a ``<failure>`` or ``<error>`` child.
    success: bool

    def __lt__(self, other):
        if not isinstance(other, TestCaseInfo):
            return NotImplemented
        # Coerce so instances built from raw XML attribute strings still
        # compare numerically ("10.0" must not rank below "9.0").
        return float(self.time) < float(other.time)


def sort_cases_by_time(suites):
    """Build ``TestCaseInfo`` records and sort them slowest first.

    Args:
        suites: iterable of ``<testcase>`` elements (despite the parameter
            name, the caller passes ``suite.findall("testcase")``).

    Returns:
        list of ``TestCaseInfo`` in descending order of ``time``.
    """
    test_cases = [_case_info(case) for case in suites]
    # Sort on the parsed float; the raw attribute strings would sort
    # lexicographically.
    test_cases.sort(key=lambda info: info.time, reverse=True)
    return test_cases


def _case_info(case):
    """Convert one ``<testcase>`` element into a ``TestCaseInfo``."""
    attrib = case.attrib
    return TestCaseInfo(
        attrib.get("classname", ""),
        attrib["name"],
        # Parsed so durations sort numerically; missing/empty counts as 0.
        float(attrib.get("time") or 0),
        # Optional on <testcase>; default instead of raising KeyError.
        attrib.get("timestamp", ""),
        # Failures and errors are child elements, not attributes.
        case.find("failure") is None and case.find("error") is None,
    )
def read_report(fpath, top_n=10):
    """Print a summary of the XML test report at ``fpath``.

    Prints one summary line covering all suites, then the ``top_n`` slowest
    testcases of each ``<testsuite>``. If the file is not well-formed XML,
    an error is written to stderr and nothing else is printed.

    Args:
        fpath: path to the XML report. The root element may be a single
            ``<testsuite>`` or an element whose direct children are
            ``<testsuite>`` elements.
        top_n: how many of the slowest testcases to print per suite.

    Returns:
        None.
    """
    try:
        # Give ElementTree the path so it reads bytes and honors the XML
        # encoding declaration, instead of decoding with the locale default.
        tree = ET.parse(fpath)
    except ET.ParseError:
        print(f"{sys.argv[0]}: Cannot parse file {fpath}", file=sys.stderr)
        return

    root = tree.getroot()
    suites = [root] if root.tag == "testsuite" else root.findall("testsuite")
    print(format_summary(summarize_testsuites(suites)))
    for suite in suites:
        cases = sort_cases_by_time(suite.findall("testcase"))
        for case in cases[:top_n]:
            print(case)
def main():
    """Command-line entry point: summarize the report named by ``--report``."""
    # Presumably silences torch_vacc logging below ERROR so only the
    # summary is printed — confirm against set_global_log_level.
    set_global_log_level("ERROR")
    read_report(parse_args().report)


if __name__ == "__main__":
    main()