From 0cc76860d57b65cbb1637f3346402b701c6e858d Mon Sep 17 00:00:00 2001 From: ZT-AIA <63220130+ZT-AIA@users.noreply.github.com> Date: Mon, 27 Apr 2026 21:48:54 +0800 Subject: [PATCH] [CI ][Misc] Add timeout check for custom op CI and optimize test parameters (#8755) ### What this PR does / why we need it? This PR introduces a mechanism to track test duration in `conftest.py` and skip subsequent tests in a file if a certain number of tests exceed a timeout threshold. This is intended to prevent CI hangs or long-running nightly tests. Additionally, it reduces the parameter space for `test_fused_qkvzba_split_reshape_cat.py` to further optimize CI runtime. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? nightly Signed-off-by: ZT-AIA <1028681969@qq.com> --- tests/e2e/nightly/single_node/ops/conftest.py | 37 +++++++++++++++++++ .../test_fused_qkvzba_split_reshape_cat.py | 10 ++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/tests/e2e/nightly/single_node/ops/conftest.py b/tests/e2e/nightly/single_node/ops/conftest.py index 681cfff8..cb177000 100644 --- a/tests/e2e/nightly/single_node/ops/conftest.py +++ b/tests/e2e/nightly/single_node/ops/conftest.py @@ -2,6 +2,43 @@ import time from datetime import datetime import pytest +DURATION_THRESHOLD = 120 +SLOW_COUNT_LIMIT = 5 + + +_per_file_slow_cases = {} +_current_file = None + + +def pytest_runtest_setup(item): + item.start_time = time.time() + + +def pytest_runtest_teardown(item, nextitem): + global _current_file + + file_path = item.fspath + duration = time.time() - item.start_time + + + if file_path not in _per_file_slow_cases: + _per_file_slow_cases[file_path] = 0 + + if duration > DURATION_THRESHOLD: + _per_file_slow_cases[file_path] += 1 + cnt = _per_file_slow_cases[file_path] + print(f" Detected that the test case took too long, ({cnt}/{SLOW_COUNT_LIMIT}):{duration:.2f}s") + + if cnt >= SLOW_COUNT_LIMIT: + print(f"\n The number of timeout test cases {file_path} ≥{SLOW_COUNT_LIMIT}\n") + _current_file = file_path + + +def pytest_runtest_call(item): + if _current_file == item.fspath: + print(f"CASE SKIP:{item.nodeid}") + pytest.skip(f"The use case takes too long.") + @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_runtest_makereport(item, call): diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_qkvzba_split_reshape_cat.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_qkvzba_split_reshape_cat.py index dc142427..9a971957 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_qkvzba_split_reshape_cat.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_fused_qkvzba_split_reshape_cat.py @@ -38,11 +38,11 @@ def validate_cmp(y_cal, y_ref, dtype, device='npu'): 'Invalid parameter \"dtype\" is found : {}'.format(dtype)) -@pytest.mark.parametrize("seq_len", [1, 16, 64, 128, 256, 1024, 2048, 3567]) -@pytest.mark.parametrize("num_heads_qk", [2, 4, 8, 16]) -@pytest.mark.parametrize("num_heads_v", [2, 4, 8]) -@pytest.mark.parametrize("head_qk_dim", [64, 128, 256]) -@pytest.mark.parametrize("head_v_dim", [64, 128]) +@pytest.mark.parametrize("seq_len", [1, 64, 1024, 2048]) +@pytest.mark.parametrize("num_heads_qk", [2, 4, 8]) +@pytest.mark.parametrize("num_heads_v", [8]) +@pytest.mark.parametrize("head_qk_dim", [256]) +@pytest.mark.parametrize("head_v_dim", [128]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) def test_fused_qkvzba_split_reshape_cat(