Files
xc-llm-ascend/tests/ut/quantization/test_quantizer.py
Zhu Yi Lin 6b80c5acba Fix W8A8 fused moe bug (#1529)
### What this PR does / why we need it?
1. Drop some unused code for W8A8 fused MoE.
2. Add an int8 KV cache check.
3. Add more unit tests.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with the newly added tests.

---------

Signed-off-by: zhuyilin <809721801@qq.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
2025-07-02 16:40:51 +08:00

125 lines
4.9 KiB
Python

import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
from unittest.mock import MagicMock, patch
from tests.ut.base import TestBase
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
W8A8Quantizer)
# Test-local registry of quantizer types. TestGetQuantizer.setUp extends it
# with MagicMock quantizer classes and tearDown restores it; the tests then
# patch its contents into
# vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE via
# patch.dict for the duration of each get_quantizer call.
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
class TestGetQuantizer(TestBase):
    """Tests for ``VLLMAscendQuantizer.get_quantizer`` using mocked quantizer classes."""

    def setUp(self):
        """Register mock quantizer classes and prepare a mock quant config."""
        # One MagicMock stands in for each quantizer class; every mock starts
        # with ``_instance`` set to None.
        self.supported_types = {
            type_name: MagicMock(_instance=None)
            for type_name in ('INT8', 'FP16', 'C8')
        }

        # Snapshot the test-local registry before extending it so that
        # tearDown can put back exactly what was there.
        self.original_supported_types = dict(SUPPORT_ASCEND_QUANTIZER_TYPE)
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)

        quant_config = MagicMock(spec=AscendQuantConfig)
        quant_config.quant_description = {"some_config": "value"}
        self.mock_quant_config = quant_config

    def tearDown(self):
        """Reset the test-local registry to its pre-test contents."""
        registry = SUPPORT_ASCEND_QUANTIZER_TYPE
        registry.clear()
        registry.update(self.original_supported_types)

    def _resolve_and_verify(self, quant_description, prefix, expected_type):
        """Call ``get_quantizer`` with the mocked registry and check the result.

        Patches the test-local registry into the quantizer module, resolves a
        quantizer for ``quant_description`` and ``prefix``, then asserts that
        the result is the ``_instance`` of the mock registered under
        ``expected_type`` and that the mock was called exactly once with
        ``quant_description``.
        """
        with patch.dict(
                'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
                SUPPORT_ASCEND_QUANTIZER_TYPE):
            quantizer = VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

            mock_cls = self.supported_types[expected_type]
            self.assertIsNotNone(quantizer)
            self.assertEqual(quantizer, mock_cls._instance)
            mock_cls.assert_called_once_with(quant_description)

    def test_get_quantizer_fa(self):
        """An '.attn' prefix with 'fa_quant_type' resolves to the mocked C8 quantizer."""
        self._resolve_and_verify(
            quant_description={'fa_quant_type': 'C8'},
            prefix='.attn',
            expected_type='C8')

    def test_get_quantizer_kv(self):
        """An '.attn' prefix with 'kv_quant_type' resolves to the mocked C8 quantizer."""
        self._resolve_and_verify(
            quant_description={'kv_quant_type': 'C8'},
            prefix='.attn',
            expected_type='C8')

    def test_get_quantizer_linear(self):
        """With ``get_linear_quant_type`` stubbed to 'INT8', the mocked INT8 quantizer is used."""
        linear_type = 'INT8'
        # Stub the linear type lookup first, then patch the registry inside
        # the helper, keeping the same nesting order of the two patches.
        with patch(
                'vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
                return_value=linear_type):
            self._resolve_and_verify(
                quant_description={'linear_type': 'INT8'},
                prefix='nothing',
                expected_type=linear_type)
class TestW8A8Quantizer(TestBase):
    """Tests that each ``W8A8Quantizer`` builder instantiates its method class."""

    def setUp(self):
        """Create the quantizer under test with an empty quant description."""
        self.quantizer = W8A8Quantizer(quant_description={})

    def _assert_builder_instantiates(self, patch_target, builder):
        """Patch ``patch_target`` and check that ``builder`` constructs it.

        Asserts that the patched class was called exactly once with no
        arguments and that the builder returned a ``MagicMock``.
        """
        with patch(patch_target, return_value=MagicMock()) as mock_method_cls:
            built = builder()
            mock_method_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_linear_method(self):
        """``build_linear_method`` calls ``AscendW8A8LinearMethod`` with no arguments."""
        self._assert_builder_instantiates(
            'vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
            self.quantizer.build_linear_method)

    def test_build_moe_method(self):
        """``build_moe_method`` calls ``AscendW8A8FusedMoEMethod`` with no arguments."""
        self._assert_builder_instantiates(
            'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
            self.quantizer.build_moe_method)

    def test_build_attention_method(self):
        """``build_attention_method`` calls ``AscendC8KVCacheMethod`` with no arguments."""
        self._assert_builder_instantiates(
            'vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
            self.quantizer.build_attention_method)