This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,62 @@
from contextlib import contextmanager
import os
import sys
import torch
from torch.testing import make_tensor
from functools import partial, wraps
import torch.testing._internal.common_device_type as cdt
from torch.testing._internal.common_device_type import (
DeviceTypeTestBase,
dtypes,
instantiate_device_type_tests,
onlyOn,
onlyPRIVATEUSE1,
ops,
)
if sys.version_info > (3, 8):
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
#skip_if_rocm,
with_dist_debug_levels,
)
else:
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
init_multigpu_helper,
skip_if_lt_x_gpu,
get_timeout,
skip_if_rocm,
with_dist_debug_levels,
)
from torch.testing._internal.common_utils import (
TestCase,
load_tests,
parametrize,
run_tests,
subtest,
retry_on_connect_failures,
instantiate_parametrized_tests,
)
onlyVacc = onlyPRIVATEUSE1
class VaccTestBase(DeviceTypeTestBase):
device_type = "vacc"
if VaccTestBase not in cdt.device_type_test_bases:
cdt.device_type_test_bases.append(VaccTestBase)
@contextmanager
def freeze_rng_state():
rng_state = torch.get_rng_state()
yield
torch.set_rng_state(rng_state)

View File

@@ -0,0 +1,103 @@
"""
Tool to summarize unit test XML reports, it summarize
* number of tests, and failure/error/skipped
* top 10 slowest tests
Usage:
python -m torch_vacc.testing.summarize_report --report report.xml
"""
import argparse
from dataclasses import dataclass
from xml.etree import ElementTree as ET
import sys
from torch_vacc import set_global_log_level
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--report", type=str)
return parser.parse_args()
def summarize_testsuites(suites):
summary = {
"errors": int,
"failures": int,
"skipped": int,
"skips": int,
"tests": int,
"time": float,
}
attribs = [s.attrib for s in suites]
for key in summary:
summary[key] = sum(summary[key](a[key]) for a in attribs if key in a)
assert not (summary["skipped"] and summary["skips"])
if summary["skips"]:
summary["skipped"] = summary["skips"]
return summary
def format_summary(summary):
template = "Ran {tests} tests in {time:.3f}s (errors={errors}, failures={failures}, skipped={skipped})"
msg = template.format(**summary)
if summary["errors"] > 0 or summary["failures"] > 0:
msg = "FAILED. " + msg
return msg
@dataclass
class TestCaseInfo:
test_class_name: str
test_name: str
time: float
timestamp: str
success: bool
def __lt__(self, other):
return self.time < other.time
def sort_cases_by_time(suites):
test_cases = [
TestCaseInfo(
s.attrib["classname"],
s.attrib["name"],
s.attrib["time"],
s.attrib["timestamp"],
s.attrib.get("failure") is None,
)
for s in suites
]
test_cases.sort(reverse=True)
return test_cases
def read_report(fpath):
with open(fpath) as report:
try:
report = ET.parse(report)
except ET.ParseError:
print(f"{sys.argv[0]}: Cannot parse file {fpath}", file=sys.stderr)
return
root = report.getroot()
suites = [root] if root.tag == "testsuite" else root.findall("testsuite")
summary = summarize_testsuites(suites)
summary_msg = format_summary(summary)
print(summary_msg)
for suite in suites:
cases = sort_cases_by_time(suite.findall("testcase"))
[print(case) for case in cases[:10]]
def main():
set_global_log_level("ERROR")
args = parse_args()
read_report(args.report)
if __name__ == "__main__":
main()