diff --git a/tests/ut/base.py b/tests/ut/base.py index 8b396d6..6bdf1f4 100644 --- a/tests/ut/base.py +++ b/tests/ut/base.py @@ -17,16 +17,13 @@ import unittest from vllm_ascend.utils import adapt_patch, register_ascend_customop -# fused moe ops test will hit the infer_schema error, we need add the patch -# here to make the test pass. -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa - class TestBase(unittest.TestCase): - def setUp(self): + def __init__(self, *args, **kwargs): # adapt patch by default. adapt_patch(True) adapt_patch() register_ascend_customop() super().setUp() + super(TestBase, self).__init__(*args, **kwargs) diff --git a/tests/ut/distributed/kv_transfer/test_simple_buffer.py b/tests/ut/distributed/kv_transfer/test_simple_buffer.py index 6f90df9..1ff81bc 100644 --- a/tests/ut/distributed/kv_transfer/test_simple_buffer.py +++ b/tests/ut/distributed/kv_transfer/test_simple_buffer.py @@ -1,9 +1,9 @@ -import unittest import zlib from unittest.mock import MagicMock import torch +from tests.ut.base import TestBase from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer, int32_hash) @@ -17,7 +17,7 @@ class MockSimplePipe: self.deallocate_buffer = MagicMock() -class TestSimpleBuffer(unittest.TestCase): +class TestSimpleBuffer(TestBase): def setUp(self): self.pipe = MockSimplePipe() diff --git a/tests/ut/distributed/kv_transfer/test_simple_connector.py b/tests/ut/distributed/kv_transfer/test_simple_connector.py index ac6c4d4..2c81943 100644 --- a/tests/ut/distributed/kv_transfer/test_simple_connector.py +++ b/tests/ut/distributed/kv_transfer/test_simple_connector.py @@ -1,17 +1,17 @@ -import unittest from unittest.mock import MagicMock, patch import torch from vllm.config import VllmConfig from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from tests.ut.base import TestBase from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer from vllm_ascend.distributed.kv_transfer.simple_connector import \ SimpleConnector from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe -class TestSimpleConnector(unittest.TestCase): +class TestSimpleConnector(TestBase): def setUp(self): self.mock_pipe = MagicMock(spec=SimplePipe) diff --git a/tests/ut/distributed/kv_transfer/test_simple_pipe.py b/tests/ut/distributed/kv_transfer/test_simple_pipe.py index efd6edd..ccc984b 100644 --- a/tests/ut/distributed/kv_transfer/test_simple_pipe.py +++ b/tests/ut/distributed/kv_transfer/test_simple_pipe.py @@ -1,12 +1,12 @@ -import unittest from unittest.mock import MagicMock, patch import torch +from tests.ut.base import TestBase from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe -class TestSimplePipe(unittest.TestCase): +class TestSimplePipe(TestBase): @classmethod def _create_mock_config(self): diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py index 91c2ad4..3b388e0 100644 --- a/tests/ut/ops/test_rotary_embedding.py +++ b/tests/ut/ops/test_rotary_embedding.py @@ -1,5 +1,4 @@ import math -import unittest from unittest.mock import MagicMock, patch import torch @@ -12,7 +11,7 @@ from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled, yarn_get_mscale) -class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): +class TestCustomRotaryEmbeddingEnabled(TestBase): def setUp(self): # Common setup for tests @@ -67,7 +66,7 @@ class TestCustomRotaryEmbeddingEnabled(unittest.TestCase): self.assertFalse(result) -class TestRopeForwardOot(unittest.TestCase): +class TestRopeForwardOot(TestBase): def setUp(self): # Common setup for tests @@ -262,7 +261,7 @@ class TestNativeRopeDeepseekForward(TestBase): assert k_pe.shape == key.shape -class TestRotateHalf(unittest.TestCase): +class TestRotateHalf(TestBase): def test_rotate_half_even_dim(self): # Test with even dimension @@ -272,7 +271,7 @@ class TestRotateHalf(unittest.TestCase): self.assertTrue(torch.allclose(result, expected)) -class TestYarnFindCorrectionDim(unittest.TestCase): +class TestYarnFindCorrectionDim(TestBase): def test_basic_case(self): # Test with standard values @@ -293,7 +292,7 @@ class TestYarnFindCorrectionDim(unittest.TestCase): self.assertTrue(torch.allclose(result, expected)) -class TestYarnGetMscale(unittest.TestCase): +class TestYarnGetMscale(TestBase): def test_scale_less_than_or_equal_1(self): self.assertEqual(yarn_get_mscale(scale=0.5), 1.0) diff --git a/tests/ut/quantization/test_func_wrapper.py b/tests/ut/quantization/test_func_wrapper.py index 2f9bd89..5020f80 100644 --- a/tests/ut/quantization/test_func_wrapper.py +++ b/tests/ut/quantization/test_func_wrapper.py @@ -1,8 +1,8 @@ -import unittest from unittest.mock import patch import torch +from tests.ut.base import TestBase from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init) @@ -20,7 +20,7 @@ class MockRMSNorm: self.ignore_anti = extra_args.get('ignore_anti', True) -class TestFuncWrapper(unittest.TestCase): +class TestFuncWrapper(TestBase): def test_wrapper_rmsnorm_init(self): diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 6eb96df..7716849 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -15,7 +15,6 @@ import math import os -import unittest from threading import Lock from unittest import mock @@ -309,7 +308,7 @@ class TestUtils(TestBase): self.assertEqual(mock_customop.register_oot.call_count, 2) -class TestProfileExecuteDuration(unittest.TestCase): +class TestProfileExecuteDuration(TestBase): def setUp(self): utils.ProfileExecuteDuration._instance = None diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py index cbfd67f..7baee71 100644 --- a/tests/ut/worker/test_input_batch.py +++ b/tests/ut/worker/test_input_batch.py @@ -1,11 +1,10 @@ -import unittest - import numpy as np import torch from vllm.sampling_params import SamplingParams from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.block_table import MultiGroupBlockTable +from tests.ut.base import TestBase from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch @@ -24,7 +23,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]): ) -class TestInputBatch(unittest.TestCase): +class TestInputBatch(TestBase): def setUp(self): self.max_num_reqs = 10