63 lines
2.6 KiB
Python
63 lines
2.6 KiB
Python
import types
|
|
|
|
from tests.ut.base import TestBase
|
|
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
|
|
get_quant_method)
|
|
|
|
|
|
class TestGetQuantMethod(TestBase):
|
|
|
|
def setUp(self):
|
|
self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
|
|
)
|
|
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
|
for layer_type in layer_map.keys():
|
|
ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
|
|
layer_type] = types.new_class(f"{quant_type}_{layer_type}")
|
|
|
|
def tearDown(self):
|
|
# Restore original map
|
|
ASCEND_QUANTIZATION_METHOD_MAP.clear()
|
|
ASCEND_QUANTIZATION_METHOD_MAP.update(
|
|
self.original_quantization_method_map)
|
|
|
|
def test_linear_quant_methods(self):
|
|
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
|
if "linear" in layer_map.keys():
|
|
prefix = "linear_layer"
|
|
cls = layer_map["linear"]
|
|
method = get_quant_method({"linear_layer.weight": quant_type},
|
|
prefix, "linear")
|
|
self.assertIsInstance(method, cls)
|
|
|
|
def test_moe_quant_methods(self):
|
|
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
|
if "moe" in layer_map.keys():
|
|
prefix = "layer"
|
|
cls = layer_map["moe"]
|
|
method = get_quant_method({"layer.weight": quant_type}, prefix,
|
|
"moe")
|
|
self.assertIsInstance(method, cls)
|
|
|
|
def test_with_fa_quant_type(self):
|
|
quant_description = {"fa_quant_type": "C8"}
|
|
method = get_quant_method(quant_description, ".attn", "attention")
|
|
self.assertIsInstance(
|
|
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
|
|
|
|
def test_with_kv_quant_type(self):
|
|
quant_description = {"kv_quant_type": "C8"}
|
|
method = get_quant_method(quant_description, ".attn", "attention")
|
|
self.assertIsInstance(
|
|
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
|
|
|
|
def test_invalid_layer_type(self):
|
|
quant_description = {"linear_layer.weight": "W8A8"}
|
|
with self.assertRaises(NotImplementedError):
|
|
get_quant_method(quant_description, "linear_layer", "unsupported")
|
|
|
|
def test_invalid_quant_type(self):
|
|
quant_description = {"linear_layer.weight": "UNKNOWN"}
|
|
with self.assertRaises(NotImplementedError):
|
|
get_quant_method(quant_description, "linear_layer", "linear")
|