[FOLLOWUP] Use base test to avoid patch everwhere (#1634)
### What this PR does / why we need it?
Use base test to avoid patch everwhere.
Followup here: https://github.com/vllm-project/vllm-ascend/pull/1566
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
ut ci passed
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
@@ -17,16 +17,13 @@ import unittest
|
|||||||
|
|
||||||
from vllm_ascend.utils import adapt_patch, register_ascend_customop
|
from vllm_ascend.utils import adapt_patch, register_ascend_customop
|
||||||
|
|
||||||
# fused moe ops test will hit the infer_schema error, we need add the patch
|
|
||||||
# here to make the test pass.
|
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
|
|
||||||
|
|
||||||
|
|
||||||
class TestBase(unittest.TestCase):
|
class TestBase(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def __init__(self, *args, **kwargs):
|
||||||
# adapt patch by default.
|
# adapt patch by default.
|
||||||
adapt_patch(True)
|
adapt_patch(True)
|
||||||
adapt_patch()
|
adapt_patch()
|
||||||
register_ascend_customop()
|
register_ascend_customop()
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
super(TestBase, self).__init__(*args, **kwargs)
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import unittest
|
|
||||||
import zlib
|
import zlib
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
|
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
|
||||||
int32_hash)
|
int32_hash)
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ class MockSimplePipe:
|
|||||||
self.deallocate_buffer = MagicMock()
|
self.deallocate_buffer = MagicMock()
|
||||||
|
|
||||||
|
|
||||||
class TestSimpleBuffer(unittest.TestCase):
|
class TestSimpleBuffer(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.pipe = MockSimplePipe()
|
self.pipe = MockSimplePipe()
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
|
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
|
||||||
from vllm_ascend.distributed.kv_transfer.simple_connector import \
|
from vllm_ascend.distributed.kv_transfer.simple_connector import \
|
||||||
SimpleConnector
|
SimpleConnector
|
||||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||||
|
|
||||||
|
|
||||||
class TestSimpleConnector(unittest.TestCase):
|
class TestSimpleConnector(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.mock_pipe = MagicMock(spec=SimplePipe)
|
self.mock_pipe = MagicMock(spec=SimplePipe)
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||||
|
|
||||||
|
|
||||||
class TestSimplePipe(unittest.TestCase):
|
class TestSimplePipe(TestBase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _create_mock_config(self):
|
def _create_mock_config(self):
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import math
|
import math
|
||||||
import unittest
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -12,7 +11,7 @@ from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
|
|||||||
yarn_get_mscale)
|
yarn_get_mscale)
|
||||||
|
|
||||||
|
|
||||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
class TestCustomRotaryEmbeddingEnabled(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Common setup for tests
|
# Common setup for tests
|
||||||
@@ -67,7 +66,7 @@ class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
|||||||
self.assertFalse(result)
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
|
||||||
class TestRopeForwardOot(unittest.TestCase):
|
class TestRopeForwardOot(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# Common setup for tests
|
# Common setup for tests
|
||||||
@@ -262,7 +261,7 @@ class TestNativeRopeDeepseekForward(TestBase):
|
|||||||
assert k_pe.shape == key.shape
|
assert k_pe.shape == key.shape
|
||||||
|
|
||||||
|
|
||||||
class TestRotateHalf(unittest.TestCase):
|
class TestRotateHalf(TestBase):
|
||||||
|
|
||||||
def test_rotate_half_even_dim(self):
|
def test_rotate_half_even_dim(self):
|
||||||
# Test with even dimension
|
# Test with even dimension
|
||||||
@@ -272,7 +271,7 @@ class TestRotateHalf(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(result, expected))
|
self.assertTrue(torch.allclose(result, expected))
|
||||||
|
|
||||||
|
|
||||||
class TestYarnFindCorrectionDim(unittest.TestCase):
|
class TestYarnFindCorrectionDim(TestBase):
|
||||||
|
|
||||||
def test_basic_case(self):
|
def test_basic_case(self):
|
||||||
# Test with standard values
|
# Test with standard values
|
||||||
@@ -293,7 +292,7 @@ class TestYarnFindCorrectionDim(unittest.TestCase):
|
|||||||
self.assertTrue(torch.allclose(result, expected))
|
self.assertTrue(torch.allclose(result, expected))
|
||||||
|
|
||||||
|
|
||||||
class TestYarnGetMscale(unittest.TestCase):
|
class TestYarnGetMscale(TestBase):
|
||||||
|
|
||||||
def test_scale_less_than_or_equal_1(self):
|
def test_scale_less_than_or_equal_1(self):
|
||||||
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
|
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import unittest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
|
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
|
||||||
wrapper_rmsnorm_init)
|
wrapper_rmsnorm_init)
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ class MockRMSNorm:
|
|||||||
self.ignore_anti = extra_args.get('ignore_anti', True)
|
self.ignore_anti = extra_args.get('ignore_anti', True)
|
||||||
|
|
||||||
|
|
||||||
class TestFuncWrapper(unittest.TestCase):
|
class TestFuncWrapper(TestBase):
|
||||||
|
|
||||||
def test_wrapper_rmsnorm_init(self):
|
def test_wrapper_rmsnorm_init(self):
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
@@ -309,7 +308,7 @@ class TestUtils(TestBase):
|
|||||||
self.assertEqual(mock_customop.register_oot.call_count, 2)
|
self.assertEqual(mock_customop.register_oot.call_count, 2)
|
||||||
|
|
||||||
|
|
||||||
class TestProfileExecuteDuration(unittest.TestCase):
|
class TestProfileExecuteDuration(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
utils.ProfileExecuteDuration._instance = None
|
utils.ProfileExecuteDuration._instance = None
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import unittest
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||||
|
|
||||||
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +23,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestInputBatch(unittest.TestCase):
|
class TestInputBatch(TestBase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.max_num_reqs = 10
|
self.max_num_reqs = 10
|
||||||
|
|||||||
Reference in New Issue
Block a user