[FOLLOWUP] Use base test to avoid patch everwhere (#1634)
### What this PR does / why we need it?
Use base test to avoid patch everwhere.
Followup here: https://github.com/vllm-project/vllm-ascend/pull/1566
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
ut ci passed
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
@@ -12,7 +11,7 @@ from vllm_ascend.ops.rotary_embedding import (custom_rotary_embedding_enabled,
|
||||
yarn_get_mscale)
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
||||
class TestCustomRotaryEmbeddingEnabled(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
@@ -67,7 +66,7 @@ class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestRopeForwardOot(unittest.TestCase):
|
||||
class TestRopeForwardOot(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Common setup for tests
|
||||
@@ -262,7 +261,7 @@ class TestNativeRopeDeepseekForward(TestBase):
|
||||
assert k_pe.shape == key.shape
|
||||
|
||||
|
||||
class TestRotateHalf(unittest.TestCase):
|
||||
class TestRotateHalf(TestBase):
|
||||
|
||||
def test_rotate_half_even_dim(self):
|
||||
# Test with even dimension
|
||||
@@ -272,7 +271,7 @@ class TestRotateHalf(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnFindCorrectionDim(unittest.TestCase):
|
||||
class TestYarnFindCorrectionDim(TestBase):
|
||||
|
||||
def test_basic_case(self):
|
||||
# Test with standard values
|
||||
@@ -293,7 +292,7 @@ class TestYarnFindCorrectionDim(unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnGetMscale(unittest.TestCase):
|
||||
class TestYarnGetMscale(TestBase):
|
||||
|
||||
def test_scale_less_than_or_equal_1(self):
|
||||
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
|
||||
|
||||
Reference in New Issue
Block a user