[CI ][Misc] Add timeout check for custom op CI and optimize test parameters (#8755)
### What this PR does / why we need it? This PR introduces a mechanism to track test duration in `conftest.py` and skip subsequent tests in a file if a certain number of tests exceed a timeout threshold. This is intended to prevent CI hangs or long-running nightly tests. Additionally, it reduces the parameter space for `test_fused_qkvzba_split_reshape_cat.py` to further optimize CI runtime. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? nightly Signed-off-by: ZT-AIA <1028681969@qq.com>
This commit is contained in:
@@ -2,6 +2,43 @@ import time
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
DURATION_THRESHOLD = 120
|
||||||
|
SLOW_COUNT_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
|
_per_file_slow_cases = {}
|
||||||
|
_current_file = None
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_runtest_setup(item):
|
||||||
|
item.start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_runtest_teardown(item, nextitem):
|
||||||
|
global _current_file
|
||||||
|
|
||||||
|
file_path = item.fspath
|
||||||
|
duration = time.time() - item.start_time
|
||||||
|
|
||||||
|
|
||||||
|
if file_path not in _per_file_slow_cases:
|
||||||
|
_per_file_slow_cases[file_path] = 0
|
||||||
|
|
||||||
|
if duration > DURATION_THRESHOLD:
|
||||||
|
_per_file_slow_cases[file_path] += 1
|
||||||
|
cnt = _per_file_slow_cases[file_path]
|
||||||
|
print(f" Detected that the test case took too long, ({cnt}/{SLOW_COUNT_LIMIT}):{duration:.2f}s")
|
||||||
|
|
||||||
|
if cnt >= SLOW_COUNT_LIMIT:
|
||||||
|
print(f"\n The number of timeout test cases {file_path} ≥{SLOW_COUNT_LIMIT}\n")
|
||||||
|
_current_file = file_path
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_runtest_call(item):
|
||||||
|
if _current_file == item.fspath:
|
||||||
|
print(f"CASE SKIP:{item.nodeid}")
|
||||||
|
pytest.skip(f"The use case takes too long.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
@pytest.hookimpl(tryfirst=True, hookwrapper=True)
|
||||||
def pytest_runtest_makereport(item, call):
|
def pytest_runtest_makereport(item, call):
|
||||||
|
|||||||
@@ -38,11 +38,11 @@ def validate_cmp(y_cal, y_ref, dtype, device='npu'):
|
|||||||
'Invalid parameter \"dtype\" is found : {}'.format(dtype))
|
'Invalid parameter \"dtype\" is found : {}'.format(dtype))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seq_len", [1, 16, 64, 128, 256, 1024, 2048, 3567])
|
@pytest.mark.parametrize("seq_len", [1, 64, 1024, 2048])
|
||||||
@pytest.mark.parametrize("num_heads_qk", [2, 4, 8, 16])
|
@pytest.mark.parametrize("num_heads_qk", [2, 4, 8])
|
||||||
@pytest.mark.parametrize("num_heads_v", [2, 4, 8])
|
@pytest.mark.parametrize("num_heads_v", [8])
|
||||||
@pytest.mark.parametrize("head_qk_dim", [64, 128, 256])
|
@pytest.mark.parametrize("head_qk_dim", [256])
|
||||||
@pytest.mark.parametrize("head_v_dim", [64, 128])
|
@pytest.mark.parametrize("head_v_dim", [128])
|
||||||
@pytest.mark.parametrize("dtype",
|
@pytest.mark.parametrize("dtype",
|
||||||
[torch.float32, torch.float16, torch.bfloat16])
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
def test_fused_qkvzba_split_reshape_cat(
|
def test_fused_qkvzba_split_reshape_cat(
|
||||||
|
|||||||
Reference in New Issue
Block a user