### What this PR does / why we need it?
Use the base test class and clean up all manual patch code:
- Clean up the EPLB config to avoid a temporary test file
- Use BaseTest with a global cache
- Add license
- Add a doc on setting up unit tests in a local environment

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
123 lines
4.8 KiB
Python
123 lines
4.8 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
|
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
|
W8A8Quantizer)
|
|
|
|
# Test-local registry dict. TestGetQuantizer.setUp() adds mock quantizer
# classes to it, and the tests pass it to patch.dict() so its entries are
# merged into vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE
# for the duration of each test.
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
|
|
|
|
|
class TestGetQuantizer(TestBase):
    """Checks that VLLMAscendQuantizer.get_quantizer picks the mocked class.

    Each test merges the test-local registry (seeded with MagicMock
    quantizer classes) into the quantizer module's registry, asks for a
    quantizer, and verifies that the expected mock class was instantiated
    exactly once with the quant description.
    """

    def setUp(self):
        """Seed the test-local registry with mock quantizer classes."""
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"some_config": "value"}

        # One mock "class" per quantization type; `_instance=None` makes each
        # mock start out without a cached instance attribute value.
        self.supported_types = {
            type_name: MagicMock(_instance=None)
            for type_name in ('INT8', 'FP16', 'C8')
        }

        # Snapshot the registry before mutating it so tearDown can undo this.
        self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)

    def tearDown(self):
        """Put the test-local registry back to its pre-setUp contents."""
        snapshot = self.original_supported_types
        SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(snapshot)

    def _registry_patch(self):
        """Return a patch.dict that merges the test registry into the module's."""
        return patch.dict(
            'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
            SUPPORT_ASCEND_QUANTIZER_TYPE)

    def _assert_resolved(self, quantizer, type_name, description):
        """Assert `quantizer` is the instance built from the `type_name` mock."""
        quantizer_cls = self.supported_types[type_name]
        self.assertIsNotNone(quantizer)
        self.assertEqual(quantizer, quantizer_cls._instance)
        quantizer_cls.assert_called_once_with(description)

    def test_get_quantizer_fa(self):
        """An 'fa_quant_type' description with an '.attn' prefix yields 'C8'."""
        description = {'fa_quant_type': 'C8'}
        with self._registry_patch():
            quantizer = VLLMAscendQuantizer.get_quantizer(
                description,
                '.attn',
                packed_modules_mapping={"some": "mapping"})

            self._assert_resolved(quantizer, 'C8', description)

    def test_get_quantizer_kv(self):
        """A 'kv_quant_type' description with an '.attn' prefix yields 'C8'."""
        description = {'kv_quant_type': 'C8'}
        with self._registry_patch():
            quantizer = VLLMAscendQuantizer.get_quantizer(
                description,
                '.attn',
                packed_modules_mapping={"some": "mapping"})

            self._assert_resolved(quantizer, 'C8', description)

    def test_get_quantizer_linear(self):
        """With get_linear_quant_type stubbed to 'INT8', 'INT8' is selected."""
        description = {'linear_type': 'INT8'}
        linear_type = 'INT8'
        # Stub the linear-type lookup first, then merge in the mock registry,
        # matching the order in which the two patches are entered.
        with patch(
                'vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
                return_value=linear_type):
            with self._registry_patch():
                quantizer = VLLMAscendQuantizer.get_quantizer(
                    description,
                    'nothing',
                    packed_modules_mapping={"some": "mapping"})

                self._assert_resolved(quantizer, linear_type, description)
|
|
|
|
|
|
class TestW8A8Quantizer(TestBase):
    """Checks that each W8A8Quantizer builder calls its patched method class."""

    def setUp(self):
        """Create the quantizer under test with an empty quant description."""
        self.quantizer = W8A8Quantizer(quant_description={})

    def _assert_builder_uses(self, target, builder_name):
        """Patch `target`, call the named builder, and verify the outcome.

        The patched factory must be called exactly once with no arguments,
        and the builder must return a MagicMock.
        """
        with patch(target, return_value=MagicMock()) as factory:
            built = getattr(self.quantizer, builder_name)()
            factory.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_linear_method(self):
        """build_linear_method calls the patched AscendW8A8LinearMethod."""
        self._assert_builder_uses(
            'vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
            'build_linear_method')

    def test_build_moe_method(self):
        """build_moe_method calls the patched AscendW8A8FusedMoEMethod."""
        self._assert_builder_uses(
            'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
            'build_moe_method')

    def test_build_attention_method(self):
        """build_attention_method calls the patched AscendC8KVCacheMethod."""
        self._assert_builder_uses(
            'vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
            'build_attention_method')
|