[1/N][Refactor][Quantization] remove redundant quantizer class (#2680)
### What this PR does / why we need it?
The AscendQuantizer/LLMQuantizer classes are used to select a quant method based
on the quant config and some other arguments,
but it is simpler and cleaner to replace these classes with a map, so I
removed them.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests (UT) and e2e tests
- vLLM version: v0.10.1.1
- vLLM main:
6997a25ac6
Signed-off-by: 22dimensions <waitingwind@foxmail.com>
This commit is contained in:
@@ -156,33 +156,22 @@ class TestAscendKVCacheMethod(TestBase):
|
||||
def setUp(self):
|
||||
# Setup common test fixtures
|
||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
||||
self.prefix = "attention_layer"
|
||||
self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
|
||||
self.prefix = "layer.attn"
|
||||
|
||||
# Mock the quantizer and quant_method
|
||||
self.mock_quantizer = MagicMock()
|
||||
# Mock quant_method
|
||||
self.mock_quant_method = MagicMock()
|
||||
|
||||
# Patch the AscendQuantizer
|
||||
self.quantizer_patcher = patch(
|
||||
'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
|
||||
return_value=self.mock_quantizer)
|
||||
self.mock_get_quantizer = self.quantizer_patcher.start()
|
||||
|
||||
self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
|
||||
self.patcher = patch(
|
||||
'vllm_ascend.quantization.quant_config.get_quant_method')
|
||||
self.mock_get_quant_method = self.patcher.start()
|
||||
self.mock_get_quant_method.return_value = self.mock_quant_method
|
||||
|
||||
# Create instance
|
||||
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
|
||||
self.prefix)
|
||||
|
||||
def tearDown(self):
|
||||
self.quantizer_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
"""Test initialization with proper quantizer setup."""
|
||||
self.mock_get_quantizer.assert_called_once_with(
|
||||
self.mock_quant_config.quant_description, self.prefix)
|
||||
self.mock_quantizer.build_attention_method.assert_called_once()
|
||||
self.patcher.stop()
|
||||
|
||||
def test_create_weights(self):
|
||||
"""Test create_weights delegates to quant_method."""
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
||||
W4A8DYNAMICQuantizer,
|
||||
W8A8Quantizer)
|
||||
|
||||
# Module-local stand-in registry. TestGetQuantizer.setUp() adds mock quantizer
# classes to it, and each test passes it to patch.dict() as the values to place
# in vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE.
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
||||
|
||||
|
||||
class TestGetQuantizer(TestBase):
    """Tests for ``VLLMAscendQuantizer.get_quantizer``.

    Each test patches the quantizer registry with mock quantizer classes and
    checks that the mock registered under the expected quant type was
    instantiated with the quant description and returned.
    """

    def setUp(self):
        # One mock quantizer class per quant type used by the tests.
        self.supported_types = {
            type_name: MagicMock(_instance=None)
            for type_name in ('INT8', 'FP16', 'C8')
        }

        # Remember the module-local registry so tearDown() can put it back,
        # then add the mocks to it.
        self.original_supported_types = dict(SUPPORT_ASCEND_QUANTIZER_TYPE)
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)

        quant_config = MagicMock(spec=AscendQuantConfig)
        quant_config.quant_description = {"some_config": "value"}
        self.mock_quant_config = quant_config

    def tearDown(self):
        # Bring the module-local registry back to its pre-test content.
        SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)

    def _registry_patch(self):
        """Return a ``patch.dict`` that injects the mock classes into the real registry."""
        return patch.dict(
            'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
            SUPPORT_ASCEND_QUANTIZER_TYPE)

    def _get_quantizer(self, quant_description, prefix):
        """Call ``get_quantizer`` with the packed-modules mapping shared by all tests."""
        return VLLMAscendQuantizer.get_quantizer(
            quant_description,
            prefix,
            packed_modules_mapping={"some": "mapping"})

    def _assert_selected(self, result, quant_description, expected_type):
        """Assert that the mock class for ``expected_type`` produced ``result``."""
        quantizer_cls = self.supported_types[expected_type]
        self.assertIsNotNone(result)
        self.assertEqual(result, quantizer_cls._instance)
        quantizer_cls.assert_called_once_with(quant_description)

    def test_get_quantizer_fa(self):
        """An ``fa_quant_type`` entry with an ``.attn`` prefix yields the C8 quantizer."""
        quant_description = {'fa_quant_type': 'C8'}

        with self._registry_patch():
            result = self._get_quantizer(quant_description, '.attn')

        self._assert_selected(result, quant_description, 'C8')

    def test_get_quantizer_kv(self):
        """A ``kv_quant_type`` entry with an ``.attn`` prefix yields the C8 quantizer."""
        quant_description = {'kv_quant_type': 'C8'}

        with self._registry_patch():
            result = self._get_quantizer(quant_description, '.attn')

        self._assert_selected(result, quant_description, 'C8')

    def test_get_quantizer_linear(self):
        """The type reported by ``get_linear_quant_type`` selects the quantizer."""
        quant_description = {'linear_type': 'INT8'}

        # get_linear_quant_type is stubbed so the lookup resolves to INT8.
        with patch(
                'vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
                return_value='INT8'), self._registry_patch():
            result = self._get_quantizer(quant_description, 'nothing')

        self._assert_selected(result, quant_description, 'INT8')
|
||||
|
||||
|
||||
class TestW8A8Quantizer(TestBase):
    """Tests that each ``W8A8Quantizer`` builder instantiates its method class."""

    def setUp(self):
        self.quantizer = W8A8Quantizer(quant_description={})

    def _assert_builds(self, target, build):
        """Patch ``target``, call ``build`` and check the patched class was used.

        The patched class must be called exactly once with no arguments, and
        ``build`` must return the mock instance it produced.
        """
        with patch(target, return_value=MagicMock()) as mocked_cls:
            built = build()
            mocked_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_linear_method(self):
        self._assert_builds(
            'vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
            self.quantizer.build_linear_method)

    def test_build_moe_method(self):
        self._assert_builds(
            'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
            self.quantizer.build_moe_method)

    def test_build_attention_method(self):
        self._assert_builds(
            'vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
            self.quantizer.build_attention_method)
|
||||
|
||||
|
||||
class TestW4A8DYNAMICQuantizer(TestBase):
    """Tests that each ``W4A8DYNAMICQuantizer`` builder instantiates its method class."""

    def setUp(self):
        self.quantizer = W4A8DYNAMICQuantizer(quant_description={})

    def _assert_builds(self, target, build):
        """Patch ``target``, call ``build`` and check the patched class was used.

        The patched class must be called exactly once with no arguments, and
        ``build`` must return the mock instance it produced.
        """
        with patch(target, return_value=MagicMock()) as mocked_cls:
            built = build()
            mocked_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_linear_method(self):
        self._assert_builds(
            'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
            self.quantizer.build_linear_method)

    def test_build_moe_method(self):
        self._assert_builds(
            'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
            self.quantizer.build_moe_method)
|
||||
62
tests/ut/quantization/test_utils.py
Normal file
62
tests/ut/quantization/test_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import types
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
|
||||
get_quant_method)
|
||||
|
||||
|
||||
class TestGetQuantMethod(TestBase):
    """Tests for ``get_quant_method`` lookups in ``ASCEND_QUANTIZATION_METHOD_MAP``.

    For the duration of each test every class registered in the map is
    replaced by an empty stand-in class, so the tests only check which
    registry entry ``get_quant_method`` instantiates.
    """

    def setUp(self):
        # dict.copy() is shallow: the per-quant-type inner dicts would still be
        # the live ones, so overwriting their entries below would also
        # overwrite the "backup" and tearDown() could never restore the real
        # classes. Copy every inner dict as well.
        self.original_quantization_method_map = {
            quant_type: dict(layer_map)
            for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items()
        }
        # Iterate the snapshot rather than the registry being modified.
        for quant_type, layer_map in self.original_quantization_method_map.items():
            for layer_type in layer_map:
                # Empty class named "<quant_type>_<layer_type>" standing in
                # for the real method class.
                ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
                    layer_type] = types.new_class(f"{quant_type}_{layer_type}")

    def tearDown(self):
        # Restore the original map, including the original method classes.
        ASCEND_QUANTIZATION_METHOD_MAP.clear()
        ASCEND_QUANTIZATION_METHOD_MAP.update(
            self.original_quantization_method_map)

    def test_linear_quant_methods(self):
        """Every quant type with a "linear" entry yields an instance of that entry."""
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            if "linear" not in layer_map:
                continue
            prefix = "linear_layer"
            cls = layer_map["linear"]
            method = get_quant_method({"linear_layer.weight": quant_type},
                                      prefix, "linear")
            self.assertIsInstance(method, cls, msg=f"quant_type={quant_type}")

    def test_moe_quant_methods(self):
        """Every quant type with a "moe" entry yields an instance of that entry."""
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            if "moe" not in layer_map:
                continue
            prefix = "layer"
            cls = layer_map["moe"]
            method = get_quant_method({"layer.weight": quant_type}, prefix,
                                      "moe")
            self.assertIsInstance(method, cls, msg=f"quant_type={quant_type}")

    def test_with_fa_quant_type(self):
        """An ``fa_quant_type`` description resolves to the C8 attention entry."""
        quant_description = {"fa_quant_type": "C8"}
        method = get_quant_method(quant_description, ".attn", "attention")
        self.assertIsInstance(
            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])

    def test_with_kv_quant_type(self):
        """A ``kv_quant_type`` description resolves to the C8 attention entry."""
        quant_description = {"kv_quant_type": "C8"}
        method = get_quant_method(quant_description, ".attn", "attention")
        self.assertIsInstance(
            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])

    def test_invalid_layer_type(self):
        """A layer type missing from the map raises ``NotImplementedError``."""
        quant_description = {"linear_layer.weight": "W8A8"}
        with self.assertRaises(NotImplementedError):
            get_quant_method(quant_description, "linear_layer", "unsupported")

    def test_invalid_quant_type(self):
        """A quant type missing from the map raises ``NotImplementedError``."""
        quant_description = {"linear_layer.weight": "UNKNOWN"}
        with self.assertRaises(NotImplementedError):
            get_quant_method(quant_description, "linear_layer", "linear")
|
||||
Reference in New Issue
Block a user