Files
xc-llm-ascend/tests/ut/quantization/test_utils.py

63 lines
2.6 KiB
Python
Raw Permalink Normal View History

import types
from tests.ut.base import TestBase
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
get_quant_method)
class TestGetQuantMethod(TestBase):
    """Tests for ``get_quant_method`` lookups in the quantization method map.

    For the duration of each test, every class registered in
    ``ASCEND_QUANTIZATION_METHOD_MAP`` is replaced by an empty placeholder
    class built with ``types.new_class``; the tests then assert that
    ``get_quant_method`` returns an instance of the placeholder registered
    for the requested quant type and layer type.
    """

    def setUp(self):
        """Swap every registered class for a placeholder class."""
        # dict.copy() is shallow: the snapshot references the ORIGINAL inner
        # dicts. They must therefore never be mutated, otherwise tearDown
        # would "restore" the placeholders instead of the real classes.
        self.original_quantization_method_map = (
            ASCEND_QUANTIZATION_METHOD_MAP.copy())
        # Iterate over the snapshot and install brand-new inner dicts, so the
        # original inner dicts (and the classes they hold) stay untouched.
        for quant_type, layer_map in (
                self.original_quantization_method_map.items()):
            ASCEND_QUANTIZATION_METHOD_MAP[quant_type] = {
                layer_type: types.new_class(f"{quant_type}_{layer_type}")
                for layer_type in layer_map
            }

    def tearDown(self):
        """Put the original inner dicts back into the shared map."""
        # Restore original map
        ASCEND_QUANTIZATION_METHOD_MAP.clear()
        ASCEND_QUANTIZATION_METHOD_MAP.update(
            self.original_quantization_method_map)

    def test_linear_quant_methods(self):
        """Each quant type with a "linear" entry yields that entry's class."""
        prefix = "linear_layer"
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            if "linear" not in layer_map:
                continue
            # subTest reports which quant type failed and keeps iterating.
            with self.subTest(quant_type=quant_type):
                cls = layer_map["linear"]
                method = get_quant_method(
                    {"linear_layer.weight": quant_type}, prefix, "linear")
                self.assertIsInstance(method, cls)

    def test_moe_quant_methods(self):
        """Each quant type with a "moe" entry yields that entry's class."""
        prefix = "layer"
        for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
            if "moe" not in layer_map:
                continue
            with self.subTest(quant_type=quant_type):
                cls = layer_map["moe"]
                method = get_quant_method({"layer.weight": quant_type},
                                          prefix, "moe")
                self.assertIsInstance(method, cls)

    def test_with_fa_quant_type(self):
        """A "fa_quant_type" of "C8" selects the C8 attention entry."""
        quant_description = {"fa_quant_type": "C8"}
        method = get_quant_method(quant_description, ".attn", "attention")
        self.assertIsInstance(
            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])

    def test_with_kv_quant_type(self):
        """A "kv_quant_type" of "C8" selects the C8 attention entry."""
        quant_description = {"kv_quant_type": "C8"}
        method = get_quant_method(quant_description, ".attn", "attention")
        self.assertIsInstance(
            method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])

    def test_invalid_layer_type(self):
        """An unknown layer type raises NotImplementedError."""
        quant_description = {"linear_layer.weight": "W8A8"}
        with self.assertRaises(NotImplementedError):
            get_quant_method(quant_description, "linear_layer", "unsupported")

    def test_invalid_quant_type(self):
        """An unknown quant type raises NotImplementedError."""
        quant_description = {"linear_layer.weight": "UNKNOWN"}
        with self.assertRaises(NotImplementedError):
            get_quant_method(quant_description, "linear_layer", "linear")