init
This commit is contained in:
62
torch_vacc/testing/__init__.py
Normal file
62
torch_vacc/testing/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from contextlib import contextmanager
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from torch.testing import make_tensor
|
||||
from functools import partial, wraps
|
||||
import torch.testing._internal.common_device_type as cdt
|
||||
from torch.testing._internal.common_device_type import (
|
||||
DeviceTypeTestBase,
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
onlyOn,
|
||||
onlyPRIVATEUSE1,
|
||||
ops,
|
||||
)
|
||||
|
||||
if sys.version_info > (3, 8):
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
init_multigpu_helper,
|
||||
skip_if_lt_x_gpu,
|
||||
get_timeout,
|
||||
#skip_if_rocm,
|
||||
with_dist_debug_levels,
|
||||
)
|
||||
else:
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcessTestCase,
|
||||
init_multigpu_helper,
|
||||
skip_if_lt_x_gpu,
|
||||
get_timeout,
|
||||
skip_if_rocm,
|
||||
with_dist_debug_levels,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
load_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
subtest,
|
||||
retry_on_connect_failures,
|
||||
instantiate_parametrized_tests,
|
||||
|
||||
)
|
||||
|
||||
onlyVacc = onlyPRIVATEUSE1
|
||||
|
||||
|
||||
class VaccTestBase(DeviceTypeTestBase):
|
||||
device_type = "vacc"
|
||||
|
||||
|
||||
if VaccTestBase not in cdt.device_type_test_bases:
|
||||
cdt.device_type_test_bases.append(VaccTestBase)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def freeze_rng_state():
|
||||
rng_state = torch.get_rng_state()
|
||||
yield
|
||||
torch.set_rng_state(rng_state)
|
||||
BIN
torch_vacc/testing/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
torch_vacc/testing/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
torch_vacc/testing/__pycache__/summarize_report.cpython-312.pyc
Normal file
BIN
torch_vacc/testing/__pycache__/summarize_report.cpython-312.pyc
Normal file
Binary file not shown.
103
torch_vacc/testing/summarize_report.py
Normal file
103
torch_vacc/testing/summarize_report.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Tool to summarize unit test XML reports, it summarize
|
||||
* number of tests, and failure/error/skipped
|
||||
* top 10 slowest tests
|
||||
|
||||
Usage:
|
||||
python -m torch_vacc.testing.summarize_report --report report.xml
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from xml.etree import ElementTree as ET
|
||||
import sys
|
||||
|
||||
from torch_vacc import set_global_log_level
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--report", type=str)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def summarize_testsuites(suites):
|
||||
summary = {
|
||||
"errors": int,
|
||||
"failures": int,
|
||||
"skipped": int,
|
||||
"skips": int,
|
||||
"tests": int,
|
||||
"time": float,
|
||||
}
|
||||
|
||||
attribs = [s.attrib for s in suites]
|
||||
for key in summary:
|
||||
summary[key] = sum(summary[key](a[key]) for a in attribs if key in a)
|
||||
assert not (summary["skipped"] and summary["skips"])
|
||||
if summary["skips"]:
|
||||
summary["skipped"] = summary["skips"]
|
||||
return summary
|
||||
|
||||
|
||||
def format_summary(summary):
|
||||
template = "Ran {tests} tests in {time:.3f}s (errors={errors}, failures={failures}, skipped={skipped})"
|
||||
msg = template.format(**summary)
|
||||
if summary["errors"] > 0 or summary["failures"] > 0:
|
||||
msg = "FAILED. " + msg
|
||||
return msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestCaseInfo:
|
||||
test_class_name: str
|
||||
test_name: str
|
||||
time: float
|
||||
timestamp: str
|
||||
success: bool
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.time < other.time
|
||||
|
||||
|
||||
def sort_cases_by_time(suites):
|
||||
test_cases = [
|
||||
TestCaseInfo(
|
||||
s.attrib["classname"],
|
||||
s.attrib["name"],
|
||||
s.attrib["time"],
|
||||
s.attrib["timestamp"],
|
||||
s.attrib.get("failure") is None,
|
||||
)
|
||||
for s in suites
|
||||
]
|
||||
test_cases.sort(reverse=True)
|
||||
return test_cases
|
||||
|
||||
|
||||
def read_report(fpath):
|
||||
with open(fpath) as report:
|
||||
try:
|
||||
report = ET.parse(report)
|
||||
except ET.ParseError:
|
||||
print(f"{sys.argv[0]}: Cannot parse file {fpath}", file=sys.stderr)
|
||||
return
|
||||
root = report.getroot()
|
||||
suites = [root] if root.tag == "testsuite" else root.findall("testsuite")
|
||||
summary = summarize_testsuites(suites)
|
||||
summary_msg = format_summary(summary)
|
||||
print(summary_msg)
|
||||
|
||||
for suite in suites:
|
||||
cases = sort_cases_by_time(suite.findall("testcase"))
|
||||
[print(case) for case in cases[:10]]
|
||||
|
||||
|
||||
def main():
|
||||
set_global_log_level("ERROR")
|
||||
args = parse_args()
|
||||
read_report(args.report)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user