diff --git a/tests/ut/quantization/test_quant_config.py b/tests/ut/quantization/test_quant_config.py index 7529fea..fa5d13e 100644 --- a/tests/ut/quantization/test_quant_config.py +++ b/tests/ut/quantization/test_quant_config.py @@ -156,33 +156,22 @@ class TestAscendKVCacheMethod(TestBase): def setUp(self): # Setup common test fixtures self.mock_quant_config = MagicMock(spec=AscendQuantConfig) - self.mock_quant_config.quant_description = {"some_config": "value"} - self.prefix = "attention_layer" + self.mock_quant_config.quant_description = {"kv_quant_type": "C8"} + self.prefix = "layer.attn" - # Mock the quantizer and quant_method - self.mock_quantizer = MagicMock() + # Mock quant_method self.mock_quant_method = MagicMock() - - # Patch the AscendQuantizer - self.quantizer_patcher = patch( - 'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer', - return_value=self.mock_quantizer) - self.mock_get_quantizer = self.quantizer_patcher.start() - - self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method + self.patcher = patch( + 'vllm_ascend.quantization.quant_config.get_quant_method') + self.mock_get_quant_method = self.patcher.start() + self.mock_get_quant_method.return_value = self.mock_quant_method # Create instance self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config, self.prefix) def tearDown(self): - self.quantizer_patcher.stop() - - def test_init(self): - """Test initialization with proper quantizer setup.""" - self.mock_get_quantizer.assert_called_once_with( - self.mock_quant_config.quant_description, self.prefix) - self.mock_quantizer.build_attention_method.assert_called_once() + self.patcher.stop() def test_create_weights(self): """Test create_weights delegates to quant_method.""" diff --git a/tests/ut/quantization/test_quantizer.py b/tests/ut/quantization/test_quantizer.py deleted file mode 100644 index a51faee..0000000 --- a/tests/ut/quantization/test_quantizer.py +++ /dev/null @@ -1,145 +0,0 @@ -from unittest.mock import MagicMock, patch - -from tests.ut.base import TestBase -from vllm_ascend.quantization.quant_config import AscendQuantConfig -from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer, - W4A8DYNAMICQuantizer, - W8A8Quantizer) - -SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"} - - -class TestGetQuantizer(TestBase): - - def setUp(self): - # Setup common test fixtures - self.supported_types = { - 'INT8': MagicMock(_instance=None), - 'FP16': MagicMock(_instance=None), - 'C8': MagicMock(_instance=None) - } - self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy() - SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types) - self.mock_quant_config = MagicMock(spec=AscendQuantConfig) - self.mock_quant_config.quant_description = {"some_config": "value"} - - def tearDown(self): - # Restore original supported types - SUPPORT_ASCEND_QUANTIZER_TYPE.clear() - SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types) - - def test_get_quantizer_fa(self): - """Test successful quantizer retrieval for different cases.""" - # Setup - quant_description = {'fa_quant_type': 'C8'} - prefix = '.attn' - expected_type = 'C8' - with patch.dict( - 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', - SUPPORT_ASCEND_QUANTIZER_TYPE): - - result = VLLMAscendQuantizer.get_quantizer( - quant_description, - prefix, - packed_modules_mapping={"some": "mapping"}) - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result, - self.supported_types[expected_type]._instance) - self.supported_types[expected_type].assert_called_once_with( - quant_description) - - def test_get_quantizer_kv(self): - """Test successful quantizer retrieval for different cases.""" - # Setup - quant_description = {'kv_quant_type': 'C8'} - prefix = '.attn' - expected_type = 'C8' - with patch.dict( - 'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', - SUPPORT_ASCEND_QUANTIZER_TYPE): - - result = VLLMAscendQuantizer.get_quantizer( - quant_description, - prefix, - packed_modules_mapping={"some": "mapping"}) - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result, - self.supported_types[expected_type]._instance) - self.supported_types[expected_type].assert_called_once_with( - quant_description) - - def test_get_quantizer_linear(self): - """Test successful quantizer retrieval for different cases.""" - # Setup - quant_description = {'linear_type': 'INT8'} - prefix = 'nothing' - expected_type = 'INT8' - with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type', - return_value=expected_type), \ - patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE): - - result = VLLMAscendQuantizer.get_quantizer( - quant_description, - prefix, - packed_modules_mapping={"some": "mapping"}) - - # Verify - self.assertIsNotNone(result) - self.assertEqual(result, - self.supported_types[expected_type]._instance) - self.supported_types[expected_type].assert_called_once_with( - quant_description) - - -class TestW8A8Quantizer(TestBase): - - def setUp(self): - self.quantizer = W8A8Quantizer(quant_description={}) - - def test_build_linear_method(self): - with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_linear_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - def test_build_moe_method(self): - with patch( - 'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_moe_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - def test_build_attention_method(self): - with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_attention_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - -class TestW4A8DYNAMICQuantizer(TestBase): - - def setUp(self): - self.quantizer = W4A8DYNAMICQuantizer(quant_description={}) - - def test_build_linear_method(self): - with patch( - 'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod', - return_value=MagicMock()) as mock_linear: - result = self.quantizer.build_linear_method() - mock_linear.assert_called_once_with() - self.assertIsInstance(result, MagicMock) - - def test_build_moe_method(self): - with patch( - 'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod', - return_value=MagicMock()) as mock_fused_moe: - result = self.quantizer.build_moe_method() - mock_fused_moe.assert_called_once_with() - self.assertIsInstance(result, MagicMock) diff --git a/tests/ut/quantization/test_utils.py b/tests/ut/quantization/test_utils.py new file mode 100644 index 0000000..153089a --- /dev/null +++ b/tests/ut/quantization/test_utils.py @@ -0,0 +1,62 @@ +import types + +from tests.ut.base import TestBase +from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP, + get_quant_method) + + +class TestGetQuantMethod(TestBase): + + def setUp(self): + self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy( + ) + for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): + for layer_type in layer_map.keys(): + ASCEND_QUANTIZATION_METHOD_MAP[quant_type][ + layer_type] = types.new_class(f"{quant_type}_{layer_type}") + + def tearDown(self): + # Restore original map + ASCEND_QUANTIZATION_METHOD_MAP.clear() + ASCEND_QUANTIZATION_METHOD_MAP.update( + self.original_quantization_method_map) + + def test_linear_quant_methods(self): + for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): + if "linear" in layer_map.keys(): + prefix = "linear_layer" + cls = layer_map["linear"] + method = get_quant_method({"linear_layer.weight": quant_type}, + prefix, "linear") + self.assertIsInstance(method, cls) + + def test_moe_quant_methods(self): + for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items(): + if "moe" in layer_map.keys(): + prefix = "layer" + cls = layer_map["moe"] + method = get_quant_method({"layer.weight": quant_type}, prefix, + "moe") + self.assertIsInstance(method, cls) + + def test_with_fa_quant_type(self): + quant_description = {"fa_quant_type": "C8"} + method = get_quant_method(quant_description, ".attn", "attention") + self.assertIsInstance( + method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"]) + + def test_with_kv_quant_type(self): + quant_description = {"kv_quant_type": "C8"} + method = get_quant_method(quant_description, ".attn", "attention") + self.assertIsInstance( + method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"]) + + def test_invalid_layer_type(self): + quant_description = {"linear_layer.weight": "W8A8"} + with self.assertRaises(NotImplementedError): + get_quant_method(quant_description, "linear_layer", "unsupported") + + def test_invalid_quant_type(self): + quant_description = {"linear_layer.weight": "UNKNOWN"} + with self.assertRaises(NotImplementedError): + get_quant_method(quant_description, "linear_layer", "linear") diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index 19df5dc..0f0353d 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -24,7 +24,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase from vllm_ascend.ascend_forward_context import _get_fused_moe_state from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod -from vllm_ascend.quantization.quantizer import W8A8Quantizer from vllm_ascend.torchair.ops.torchair_fused_moe import ( TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod) from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 @@ -236,12 +235,9 @@ class TestTorchairAscendFusedMoe: mock_quant_method = MockFusedMoEMethod() mock_quant_config.get_quant_method.return_value = mock_quant_method mock_quant_config.is_layer_skipped_ascend.return_value = False - with patch( - 'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer', - return_value=W8A8Quantizer): + with patch("vllm_ascend.quantization.quant_config.get_quant_method"): moe = TorchairAscendFusedMoE(**default_moe_config, quant_config=mock_quant_config) - assert moe.quant_method is not None assert isinstance(moe.quant_method, AscendFusedMoEMethod) diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py index fb526b5..edd3fc2 100644 --- a/tests/ut/torchair/test_utils.py +++ b/tests/ut/torchair/test_utils.py @@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE from vllm_ascend.torchair import utils @@ -135,15 +134,3 @@ class TestTorchairUtils(TestBase): utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ) mock_npu_cast.assert_not_called() - - def test_torchair_quant_method_register(self): - - TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W8A8_DYNAMIC"] - TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W4A8_DYNAMIC"] - utils.torchair_quant_method_register() - self.assertNotEqual(TorchairW8A8DYNAMICQuantizer, - SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"]) - self.assertNotEqual(TorchairW4A8DYNAMICQuantizer, - SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"]) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index d449c8d..7299dbe 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -38,7 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD -from .quantizer import AscendQuantizer +from .utils import get_quant_method @register_quantization_config(ASCEND_QUANTIZATION_METHOD) @@ -150,18 +150,15 @@ class AscendQuantConfig(QuantizationConfig): class AscendLinearMethod(LinearMethodBase): """Linear method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for linear methods. - Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, packed_modules_mapping: Dict[str, Any]) -> None: - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_linear_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "linear", + packed_modules_mapping) def create_weights( self, @@ -231,17 +228,13 @@ class AscendLinearMethod(LinearMethodBase): class AscendKVCacheMethod(BaseKVCacheMethod): """KVCache method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for kvcache methods. - Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix) - self.quant_method = self.quantizer.build_attention_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "attention") def create_weights(self, layer: torch.nn.Module) -> None: # Different from linear method, there are no weight processing/slicing @@ -263,18 +256,15 @@ class AscendKVCacheMethod(BaseKVCacheMethod): class AscendFusedMoEMethod(FusedMoEMethodBase): """FusedMoE method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for kvcache methods. - Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, packed_modules_mapping: Dict[str, Any]): - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_moe_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "moe", + packed_modules_mapping) def create_weights( self, @@ -344,14 +334,13 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): class AscendEmbeddingMethod(AscendLinearMethod): """Embedding method for Ascend quantization. - This class calls AscendQuantizer to search a specific quantization - implementations supported on ascend hardware for Embedding methods. + Args: quant_config: The Ascend quantization config. """ def __init__(self, quant_config: AscendQuantConfig, prefix: str, packed_modules_mapping: Dict[str, Any]) -> None: - self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix, packed_modules_mapping) - self.quant_method = self.quantizer.build_linear_method() + self.quant_method = get_quant_method(quant_config.quant_description, + prefix, "linear", + packed_modules_mapping) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py deleted file mode 100644 index 0e15ed2..0000000 --- a/vllm_ascend/quantization/quantizer.py +++ /dev/null @@ -1,311 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import importlib -import sys -import types -from typing import Any, Dict, List, Optional - -from vllm.logger import logger - -from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init, - wrapper_vocab_parallel_embedding_init) -from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, - AscendW4A8DynamicLinearMethod) -from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, - AscendW8A8LinearMethod) -from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, - AscendW8A8DynamicLinearMethod) - -CUSTOMIZED_QUANTIZER_TYPE: List[str] = [] - - -class AscendQuantizer: - """An interface to different quantization implementations for ascend hardwares.""" - - @classmethod - def get_quantizer(cls, - quant_config: Dict[str, Any], - prefix: str, - packed_modules_mapping: Optional[Dict[str, - Any]] = dict()): - # TODO: Need a param to choose quantization algorithms. - quantization_algorithm = '' - - if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE: - return - - return VLLMAscendQuantizer.get_quantizer(quant_config, prefix, - packed_modules_mapping) - - def build_linear_method(self): - raise NotImplementedError - - def build_moe_method(self): - raise NotImplementedError - - def build_attention_method(self): - raise NotImplementedError - - -class VLLMAscendQuantizer: - _instance: Optional[object] = None - patched = False - - def __init__(self, quant_description): - if VLLMAscendQuantizer.patched: - return - for name in quant_description.keys(): - if "norm.bias" in name: - VLLMAscendQuantizer.apply_patch( - "vllm.model_executor.layers.layernorm.RMSNorm", "__init__", - [wrapper_rmsnorm_init]) - VLLMAscendQuantizer.apply_patch( - "vllm_ascend.ops.layernorm.AscendRMSNorm", "forward_oot", - [wrapper_rmsnorm_forward_oot]) - VLLMAscendQuantizer.apply_patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding", - "__init__", [wrapper_vocab_parallel_embedding_init]) - break - VLLMAscendQuantizer.patched = True - logger.info("Using the vLLM Ascend Quantizer version now!") - - @staticmethod - def apply_patch(target_module, target_function, wrappers): - - original_module, original_function = VLLMAscendQuantizer.parse_path( - target_module, target_function, False) - - original_function_id = id(original_function) - - candidate = original_function - for wrapper in wrappers: - candidate = wrapper(candidate) - if target_function is not None: - setattr(original_module, target_function, candidate) - - for _, value in sys.modules.copy().items(): - if target_function is None: - continue - try: - attr = getattr(value, target_function, None) - if attr is not None and id(attr) == original_function_id: - setattr(value, target_function, candidate) - except ImportError: - continue - - @staticmethod - def parse_path(module_path, function_name, create_dummy): - """ - Parse module path and resolve/create modules as needed. - - Args: - module_path: Dot-separated module path - function_name: Target function name (None for module only) - create_dummy: Create dummy modules/functions when missing - - Returns: - Tuple of (resolved module, target function/none) - - Raises: - ModuleNotFoundError: If module path is invalid and create_dummy=False - AttributeError: If function is missing and create_dummy=False - """ - from importlib.machinery import ModuleSpec - - def create_dummy_module(full_path, parent=None): - """Create and register a placeholder module""" - dummy = types.ModuleType(full_path) - dummy.__file__ = "vllm_ascend.dummy_module.py" - dummy.__spec__ = ModuleSpec(full_path, None) - sys.modules[full_path] = dummy - if parent: - setattr(parent, full_path.split(".")[-1], dummy) - return dummy - - def create_placeholder_function(func_name): - """Create dummy function that raises when called""" - - def placeholder(*args, **kwargs): - raise NotImplementedError( - f"Function {func_name} is a placeholder") - - placeholder.__name__ = func_name - return placeholder - - modules = module_path.split(".") - current_module = None - processed_path = [] - - for idx, part in enumerate(modules): - current_path = ".".join(modules[:idx + 1]) - parent_path = ".".join(modules[:idx]) if idx > 0 else None - - try: - current_module = importlib.import_module(current_path) - except ModuleNotFoundError: - # Handle missing module - parent = importlib.import_module( - parent_path) if parent_path else None - if parent and hasattr(parent, part): - # Use existing attribute from parent - current_module = getattr(parent, part) - # Check for early function resolution - if function_name and hasattr(current_module, - function_name): - return current_module, getattr(current_module, - function_name) - if function_name and create_dummy: - ph_func = create_placeholder_function(function_name) - setattr(current_module, function_name, ph_func) - return current_module, ph_func - if function_name: - raise AttributeError( - f"Function {function_name} missing in {current_path}" - ) - else: - if not create_dummy: - raise - # Create and register dummy module - current_module = create_dummy_module( - current_path, - parent=importlib.import_module(parent_path) - if parent_path else None) - - processed_path.append(part) - - # Final function handling - final_module = sys.modules[module_path] - if function_name is not None: - if not hasattr(final_module, function_name): - if create_dummy: - ph_func = create_placeholder_function(function_name) - setattr(final_module, function_name, ph_func) - else: - setattr(final_module, function_name, None) - return final_module, getattr(final_module, function_name) - - return final_module, None - - @staticmethod - def build_linear_method(): - raise NotImplementedError( - "Linear method is not implemented for the current quant type.") - - @staticmethod - def build_moe_method(): - raise NotImplementedError( - "MoE method is not implemented for the current quant type.") - - @staticmethod - def build_attention_method(): - raise NotImplementedError( - "Attention method is not implemented for the current quant type.") - - @staticmethod - def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, - packed_modules_mapping: Dict[str, Any]): - proj_name = prefix.split(".")[-1] - if proj_name in packed_modules_mapping: - quant_type = None - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in packed_modules_mapping[proj_name] - ] - for shard_prefix in shard_prefixes: - shard_quant_type = quant_description[shard_prefix + '.weight'] - - if quant_type is None: - quant_type = shard_quant_type - elif shard_quant_type != quant_type: - raise ValueError( - f"Not all shards of {prefix} are quantized with same quant type." - f"Shard {proj_name} uses {shard_quant_type}, but another shard" - f"use {quant_type}. Please check quantization config.") - else: - quant_type = quant_description[prefix + '.weight'] - return quant_type - - @classmethod - def get_quantizer(cls, - quant_description: Dict[str, Any], - prefix: str, - packed_modules_mapping: Optional[Dict[str, Any]] = None): - if packed_modules_mapping is None: - packed_modules_mapping = dict() - # Attention - if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): - quant_type = quant_description['fa_quant_type'] - # Use KVCache int8 - elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): - quant_type = quant_description['kv_quant_type'] - # Linear - else: - quant_type = cls.get_linear_quant_type(quant_description, prefix, - packed_modules_mapping) - if quant_type in SUPPORT_ASCEND_QUANTIZER_TYPE.keys(): - cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type] - if not cls._instance: - cls._instance = cls(quant_description) - return cls._instance - raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ - f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}") - - -class W4A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return AscendW4A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return AscendW4A8DynamicFusedMoEMethod() - - -class W8A8Quantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return AscendW8A8LinearMethod() - - @staticmethod - def build_moe_method(): - return AscendW8A8FusedMoEMethod() - - @staticmethod - def build_attention_method(): - return AscendC8KVCacheMethod() - - -class W8A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return AscendW8A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return AscendW8A8DynamicFusedMoEMethod() - - -SUPPORT_ASCEND_QUANTIZER_TYPE = { - "W4A8_DYNAMIC": W4A8DYNAMICQuantizer, - "W8A8": W8A8Quantizer, - "W8A8_DYNAMIC": W8A8DYNAMICQuantizer, - "C8": W8A8Quantizer, -} diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py new file mode 100644 index 0000000..6783f12 --- /dev/null +++ b/vllm_ascend/quantization/utils.py @@ -0,0 +1,222 @@ +import importlib +import sys +import types +from typing import Any, Dict, Optional, Type + +from vllm.logger import logger + +from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init, + wrapper_vocab_parallel_embedding_init) +from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, + AscendW4A8DynamicLinearMethod) +from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, + AscendW8A8LinearMethod) +from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, + AscendW8A8DynamicLinearMethod) + +patched = False + +ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = { + "W4A8_DYNAMIC": { + "linear": AscendW4A8DynamicLinearMethod, + "moe": AscendW4A8DynamicFusedMoEMethod, + }, + "W8A8": { + "linear": AscendW8A8LinearMethod, + "moe": AscendW8A8FusedMoEMethod, + "attention": AscendC8KVCacheMethod, + }, + "W8A8_DYNAMIC": { + "linear": AscendW8A8DynamicLinearMethod, + "moe": AscendW8A8DynamicFusedMoEMethod, + }, + "C8": { + "attention": AscendC8KVCacheMethod, + }, +} + + +def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str, + packed_modules_mapping: Dict[str, Any]): + proj_name = prefix.split(".")[-1] + if proj_name in packed_modules_mapping: + quant_type = None + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in packed_modules_mapping[proj_name] + ] + for shard_prefix in shard_prefixes: + shard_quant_type = quant_description[shard_prefix + '.weight'] + + if quant_type is None: + quant_type = shard_quant_type + elif shard_quant_type != quant_type: + raise ValueError( + f"Not all shards of {prefix} are quantized with same quant type." + f"Shard {proj_name} uses {shard_quant_type}, but another shard" + f"use {quant_type}. Please check quantization config.") + else: + quant_type = quant_description[prefix + '.weight'] + return quant_type + + +def get_quant_method(quant_description: Dict[str, Any], + prefix: str, + layer_type: str, + packed_modules_mapping: Optional[Dict[str, Any]] = None): + apply_quantization_patch(quant_description) + if packed_modules_mapping is None: + packed_modules_mapping = dict() + # Attention + if '.attn' in prefix and 'fa_quant_type' in quant_description.keys(): + quant_type = quant_description['fa_quant_type'] + # Use KVCache int8 + elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys(): + quant_type = quant_description['kv_quant_type'] + # Linear + else: + quant_type = get_linear_quant_type(quant_description, prefix, + packed_modules_mapping) + if quant_type in ASCEND_QUANTIZATION_METHOD_MAP.keys(): + method_map = ASCEND_QUANTIZATION_METHOD_MAP[quant_type] + if layer_type in method_map.keys(): + method_cls = method_map[layer_type] + return method_cls() + else: + raise NotImplementedError( + f"Currently, vLLM Ascend doesn't support {quant_type} for {layer_type}." + ) + raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ + f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") + + +def apply_quantization_patch(quant_description): + global patched + if patched: + return + for name in quant_description.keys(): + if "norm.bias" in name: + apply_patch("vllm.model_executor.layers.layernorm.RMSNorm", + "__init__", [wrapper_rmsnorm_init]) + apply_patch("vllm_ascend.ops.layernorm.AscendRMSNorm", + "forward_oot", [wrapper_rmsnorm_forward_oot]) + apply_patch( + "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding", + "__init__", [wrapper_vocab_parallel_embedding_init]) + break + patched = True + logger.info("Using the vLLM Ascend Quantization now!") + + +def apply_patch(target_module, target_function, wrappers): + + original_module, original_function = parse_path(target_module, + target_function, False) + + original_function_id = id(original_function) + + candidate = original_function + for wrapper in wrappers: + candidate = wrapper(candidate) + if target_function is not None: + setattr(original_module, target_function, candidate) + + for _, value in sys.modules.copy().items(): + if target_function is None: + continue + try: + attr = getattr(value, target_function, None) + if attr is not None and id(attr) == original_function_id: + setattr(value, target_function, candidate) + except ImportError: + continue + + +def parse_path(module_path, function_name, create_dummy): + """ + Parse module path and resolve/create modules as needed. + + Args: + module_path: Dot-separated module path + function_name: Target function name (None for module only) + create_dummy: Create dummy modules/functions when missing + + Returns: + Tuple of (resolved module, target function/none) + + Raises: + ModuleNotFoundError: If module path is invalid and create_dummy=False + AttributeError: If function is missing and create_dummy=False + """ + from importlib.machinery import ModuleSpec + + def create_dummy_module(full_path, parent=None): + """Create and register a placeholder module""" + dummy = types.ModuleType(full_path) + dummy.__file__ = "vllm_ascend.dummy_module.py" + dummy.__spec__ = ModuleSpec(full_path, None) + sys.modules[full_path] = dummy + if parent: + setattr(parent, full_path.split(".")[-1], dummy) + return dummy + + def create_placeholder_function(func_name): + """Create dummy function that raises when called""" + + def placeholder(*args, **kwargs): + raise NotImplementedError(f"Function {func_name} is a placeholder") + + placeholder.__name__ = func_name + return placeholder + + modules = module_path.split(".") + current_module = None + processed_path = [] + + for idx, part in enumerate(modules): + current_path = ".".join(modules[:idx + 1]) + parent_path = ".".join(modules[:idx]) if idx > 0 else None + + try: + current_module = importlib.import_module(current_path) + except ModuleNotFoundError: + # Handle missing module + parent = importlib.import_module( + parent_path) if parent_path else None + if parent and hasattr(parent, part): + # Use existing attribute from parent + current_module = getattr(parent, part) + # Check for early function resolution + if function_name and hasattr(current_module, function_name): + return current_module, getattr(current_module, + function_name) + if function_name and create_dummy: + ph_func = create_placeholder_function(function_name) + setattr(current_module, function_name, ph_func) + return current_module, ph_func + if function_name: + raise AttributeError( + f"Function {function_name} missing in {current_path}") + else: + if not create_dummy: + raise + # Create and register dummy module + current_module = create_dummy_module( + current_path, + parent=importlib.import_module(parent_path) + if parent_path else None) + + processed_path.append(part) + + # Final function handling + final_module = sys.modules[module_path] + if function_name is not None: + if not hasattr(final_module, function_name): + if create_dummy: + ph_func = create_placeholder_function(function_name) + setattr(final_module, function_name, ph_func) + else: + setattr(final_module, function_name, None) + return final_module, getattr(final_module, function_name) + + return final_module, None diff --git a/vllm_ascend/torchair/quantization/torchair_quantizer.py b/vllm_ascend/torchair/quantization/torchair_quantizer.py deleted file mode 100644 index 1d1d584..0000000 --- a/vllm_ascend/torchair/quantization/torchair_quantizer.py +++ /dev/null @@ -1,29 +0,0 @@ -from vllm_ascend.quantization.quantizer import VLLMAscendQuantizer -from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( - TorchairAscendW4A8DynamicFusedMoEMethod, - TorchairAscendW4A8DynamicLinearMethod) -from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( - TorchairAscendW8A8DynamicFusedMoEMethod, - TorchairAscendW8A8DynamicLinearMethod) - - -class TorchairW8A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return TorchairAscendW8A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return TorchairAscendW8A8DynamicFusedMoEMethod() - - -class TorchairW4A8DYNAMICQuantizer(VLLMAscendQuantizer): - - @staticmethod - def build_linear_method(): - return TorchairAscendW4A8DynamicLinearMethod() - - @staticmethod - def build_moe_method(): - return TorchairAscendW4A8DynamicFusedMoEMethod() diff --git a/vllm_ascend/torchair/utils.py b/vllm_ascend/torchair/utils.py index 13d5879..fcf2914 100644 --- a/vllm_ascend/torchair/utils.py +++ b/vllm_ascend/torchair/utils.py @@ -180,15 +180,22 @@ def register_torchair_model(): def torchair_quant_method_register(): - from vllm_ascend.quantization.quantizer import \ - SUPPORT_ASCEND_QUANTIZER_TYPE - from vllm_ascend.torchair.quantization.torchair_quantizer import ( - TorchairW4A8DYNAMICQuantizer, TorchairW8A8DYNAMICQuantizer) + from vllm_ascend.quantization.utils import ASCEND_QUANTIZATION_METHOD_MAP + from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import ( + TorchairAscendW4A8DynamicFusedMoEMethod, + TorchairAscendW4A8DynamicLinearMethod) + from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( + TorchairAscendW8A8DynamicFusedMoEMethod, + TorchairAscendW8A8DynamicLinearMethod) - SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W8A8_DYNAMIC"] = TorchairW8A8DYNAMICQuantizer - SUPPORT_ASCEND_QUANTIZER_TYPE[ - "W4A8_DYNAMIC"] = TorchairW4A8DYNAMICQuantizer + ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ + "linear"] = TorchairAscendW8A8DynamicLinearMethod + ASCEND_QUANTIZATION_METHOD_MAP["W8A8_DYNAMIC"][ + "moe"] = TorchairAscendW8A8DynamicFusedMoEMethod + ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ + "linear"] = TorchairAscendW4A8DynamicLinearMethod + ASCEND_QUANTIZATION_METHOD_MAP["W4A8_DYNAMIC"][ + "moe"] = TorchairAscendW4A8DynamicFusedMoEMethod def torchair_ops_patch():