Files
enginex-mlu370-any2any/transformers/utils/scan_skipped_tests.py

200 lines
7.6 KiB
Python
Raw Permalink Normal View History

2025-10-09 16:47:16 +08:00
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import json
import re
import warnings
from pathlib import Path
# All paths below are resolved against the current working directory, so the
# script is expected to be launched from the repository root.
REPO_ROOT = Path.cwd()
# Files that define the shared ("common") tests, each paired with the origin
# tag reported for the tests found in it.
COMMON_TEST_FILES: list[tuple[Path, str]] = [
    (REPO_ROOT / "tests/test_modeling_common.py", "common"),
    (REPO_ROOT / "tests/generation/test_utils.py", "GenerationMixin"),
]
# Root directory searched recursively for per-model ``test_modeling_*.py`` files.
MODELS_DIR = REPO_ROOT / "tests/models"
def get_common_tests(file_paths_with_origin: list[tuple[Path, str]]) -> dict[str, str]:
"""Extract all common test function names (e.g., 'test_forward')."""
tests_with_origin: dict[str, str] = {}
for file_path, origin_tag in file_paths_with_origin:
if not file_path.is_file():
continue
content = file_path.read_text(encoding="utf-8")
for test_name in re.findall(r"^\s*def\s+(test_[A-Za-z0-9_]+)", content, re.MULTILINE):
tests_with_origin[test_name] = origin_tag
return tests_with_origin
def get_models_and_test_files(models_dir: Path) -> tuple[list[str], list[Path]]:
    """Discover every ``test_modeling_*.py`` file below *models_dir*.

    Returns:
        ``(model_names, test_files)``. ``test_files`` is the sorted list of
        matching paths (searched recursively); ``model_names`` is the sorted,
        de-duplicated list of their parent-directory names.

    Raises:
        FileNotFoundError: if *models_dir* is not an existing directory.
    """
    if models_dir.is_dir():
        found_files = list(models_dir.rglob("test_modeling_*.py"))
        found_files.sort()
        unique_models = sorted(set(map(lambda path: path.parent.name, found_files)))
        return unique_models, found_files
    raise FileNotFoundError(f"Models directory not found at {models_dir}")
def _extract_reason_from_decorators(decorators_block: str) -> str:
"""Extracts the reason string from a decorator block, if any."""
reason_match = re.search(r'reason\s*=\s*["\'](.*?)["\']', decorators_block)
if reason_match:
return reason_match.group(1)
reason_match = re.search(r'\((?:.*?,\s*)?["\'](.*?)["\']\)', decorators_block)
if reason_match:
return reason_match.group(1)
return decorators_block.strip().split("\n")[-1].strip()
def extract_test_info(file_content: str) -> dict[str, tuple[str, str]]:
"""
Parse a test file once and return a mapping of test functions to their
status and skip reason, e.g. {'test_forward': ('SKIPPED', 'too slow')}.
"""
result: dict[str, tuple[str, str]] = {}
pattern = re.compile(r"((?:^\s*@.*?\n)*?)^\s*def\s+(test_[A-Za-z0-9_]+)\b", re.MULTILINE)
for decorators_block, test_name in pattern.findall(file_content):
if "skip" in decorators_block:
result[test_name] = ("SKIPPED", _extract_reason_from_decorators(decorators_block))
else:
result[test_name] = ("RAN", "")
return result
def build_model_overrides(model_test_files: list[Path]) -> dict[str, dict[str, tuple[str, str]]]:
    """Return *model_name → {test_name → (status, reason)}* mapping.

    Each file is read once (UTF-8) and attributed to the model named by its
    parent directory. When a model has several test files, their entries are
    merged in the order given; a later file's entry for a test name replaces
    an earlier one.
    """
    merged: dict[str, dict[str, tuple[str, str]]] = {}
    for test_file in model_test_files:
        model_key = test_file.parent.name
        file_info = extract_test_info(test_file.read_text(encoding="utf-8"))
        if model_key in merged:
            merged[model_key].update(file_info)
        else:
            merged[model_key] = dict(file_info)
    return merged
def save_json(obj: dict, output_path: Path) -> None:
    """Write *obj* to *output_path* as UTF-8 JSON indented by two spaces.

    Missing parent directories are created first; an existing file at
    *output_path* is overwritten.
    """
    destination_dir = output_path.parent
    destination_dir.mkdir(exist_ok=True, parents=True)
    payload = json.dumps(obj, indent=2)
    output_path.write_text(payload, encoding="utf-8")
def _collect_skip_status(
    test_name: str,
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> tuple[list[str], list[str], list[str], float]:
    """Split *model_names* by whether each model runs or skips *test_name*.

    A model without an entry for *test_name* in *model_overrides* counts as
    running it. Returns ``(models_ran, models_skipped, reasons_for_skipping,
    skipped_ratio)``: the lists follow the order of *model_names*, each reason
    is formatted ``"<model>: <reason>"``, and the ratio is 0.0 when there are
    no models.
    """
    models_ran: list[str] = []
    models_skipped: list[str] = []
    reasons_for_skipping: list[str] = []
    for model_name in model_names:
        status, reason = model_overrides.get(model_name, {}).get(test_name, ("RAN", ""))
        if status == "SKIPPED":
            models_skipped.append(model_name)
            reasons_for_skipping.append(f"{model_name}: {reason}")
        else:
            models_ran.append(model_name)
    skipped_ratio = len(models_skipped) / len(model_names) if model_names else 0.0
    return models_ran, models_skipped, reasons_for_skipping, skipped_ratio


def _format_summary(
    models_ran: list[str],
    models_skipped: list[str],
    reasons_for_skipping: list[str],
    skipped_ratio: float,
) -> dict[str, object]:
    """Build the JSON-ready summary shared by the single-test and all-tests reports."""
    return {
        "models_ran": sorted(models_ran),
        "models_skipped": sorted(models_skipped),
        "skipped_proportion": round(skipped_ratio, 4),
        "reasons_skipped": sorted(reasons_for_skipping),
    }


def summarize_single_test(
    test_name: str,
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> dict[str, object]:
    """Print a concise terminal summary for *test_name* and return the raw data.

    Returns a dict with sorted ``models_ran`` / ``models_skipped`` /
    ``reasons_skipped`` lists and ``skipped_proportion`` rounded to 4 digits.
    """
    models_ran, models_skipped, reasons_for_skipping, skipped_ratio = _collect_skip_status(
        test_name, model_names, model_overrides
    )
    total_models = len(model_names)
    print(f"\n== {test_name} ==")
    print(f"Ran : {len(models_ran)}/{total_models}")
    print(f"Skipped : {len(models_skipped)}/{total_models} ({skipped_ratio:.1%})")
    # Show at most 10 reasons (in model order) to keep the terminal output short.
    for reason_entry in reasons_for_skipping[:10]:
        print(f" - {reason_entry}")
    if len(reasons_for_skipping) > 10:
        print(" - ...")
    return _format_summary(models_ran, models_skipped, reasons_for_skipping, skipped_ratio)


def summarize_all_tests(
    tests_with_origin: dict[str, str],
    model_names: list[str],
    model_overrides: dict[str, dict[str, tuple[str, str]]],
) -> dict[str, object]:
    """Return aggregated data for every discovered common test.

    The result maps each test name to its ``origin`` tag plus the same summary
    fields returned by ``summarize_single_test``.
    """
    results: dict[str, object] = {}
    test_names = list(tests_with_origin)
    print(f"📝 Aggregating {len(test_names)} tests...")
    progress_lines = [f" ({index}/{len(test_names)}) {test_fn}" for index, test_fn in enumerate(test_names, 1)]
    # "\r" only moves the cursor back; pad every line to the longest one so a
    # shorter test name fully overwrites the previous, longer one.
    progress_width = max((len(line) for line in progress_lines), default=0)
    for test_fn, progress_line in zip(test_names, progress_lines):
        print(progress_line.ljust(progress_width), end="\r")
        summary = _format_summary(*_collect_skip_status(test_fn, model_names, model_overrides))
        # "origin" first to keep the established key order in the JSON output.
        results[test_fn] = {"origin": tests_with_origin[test_fn], **summary}
    print("\n✅ Scan complete.")
    return results
def main() -> None:
    """Parse CLI arguments, scan the model test files and write the JSON report.

    With ``--test_method_name`` only that test is summarized and the report is
    written to ``scan_<name>.json``; otherwise every common test is summarized
    into ``all_tests_scan_result.json``. Both go under ``--output_dir``.
    """
    parser = argparse.ArgumentParser(
        description="Scan model tests for overridden or skipped common or generate tests.",
    )
    parser.add_argument(
        "--output_dir",
        default=".",
        help="Directory for JSON output (default: %(default)s)",
    )
    parser.add_argument(
        "--test_method_name",
        help="Scan only this test method (singletest mode)",
    )
    args = parser.parse_args()
    output_dir = Path(args.output_dir).expanduser()
    test_method_name = args.test_method_name
    tests_with_origin = get_common_tests(COMMON_TEST_FILES)
    if test_method_name:
        if test_method_name not in tests_with_origin:
            # The scan still runs; only the origin is unknown, so say so explicitly.
            print(f"⚠️  '{test_method_name}' was not found in the common test files; origin set to 'unknown'.")
        tests_with_origin = {test_method_name: tests_with_origin.get(test_method_name, "unknown")}
    model_names, model_test_files = get_models_and_test_files(MODELS_DIR)
    print(f"🔬 Parsing {len(model_test_files)} model test files once each...")
    model_overrides = build_model_overrides(model_test_files)
    if test_method_name:
        summary = summarize_single_test(test_method_name, model_names, model_overrides)
        # Include the origin so a single-test report carries the same keys as
        # each entry of the full scan.
        data = {"origin": tests_with_origin[test_method_name], **summary}
        json_path = output_dir / f"scan_{test_method_name}.json"
    else:
        data = summarize_all_tests(tests_with_origin, model_names, model_overrides)
        json_path = output_dir / "all_tests_scan_result.json"
    save_json(data, json_path)
    print(f"\n📄 JSON saved to {json_path.resolve()}")


if __name__ == "__main__":
    main()