From d51694a77bbbeba6b45a54ecb4cb04559266b129 Mon Sep 17 00:00:00 2001 From: 22dimensions Date: Mon, 8 Sep 2025 17:31:53 +0800 Subject: [PATCH] [2/N][Refactor][Quantization] clean quantization patch (#2785) ### What this PR does / why we need it? quantization patch is unused code ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? tested by CI - vLLM version: v0.10.1.1 - vLLM main: https://github.com/vllm-project/vllm/commit/f4962a6d55a340ebb569d377c842deff7611d8f7 Signed-off-by: 22dimensions --- tests/ut/quantization/test_func_wrapper.py | 134 -------------- vllm_ascend/ops/vocab_parallel_embedding.py | 1 + vllm_ascend/quantization/func_wrapper.py | 184 -------------------- vllm_ascend/quantization/utils.py | 139 +-------------- 4 files changed, 2 insertions(+), 456 deletions(-) delete mode 100644 tests/ut/quantization/test_func_wrapper.py delete mode 100644 vllm_ascend/quantization/func_wrapper.py diff --git a/tests/ut/quantization/test_func_wrapper.py b/tests/ut/quantization/test_func_wrapper.py deleted file mode 100644 index 5020f80..0000000 --- a/tests/ut/quantization/test_func_wrapper.py +++ /dev/null @@ -1,134 +0,0 @@ -from unittest.mock import patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot, - wrapper_rmsnorm_init) - - -class MockRMSNorm: - - def __init__(self, hidden_size: int, **extra_args): - self.hidden_size = hidden_size - self.weight = torch.ones(hidden_size) - self.input_scale = 1.0 - self.input_offset = 0.0 - self.variance_epsilon = 1e-6 - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - self.ignore_anti = extra_args.get('ignore_anti', True) - - -class TestFuncWrapper(TestBase): - - def test_wrapper_rmsnorm_init(self): - - @wrapper_rmsnorm_init - def init(self, hidden_size: int, **extra_args) -> None: - self.hidden_size = hidden_size - - hidden_size = 128 - extra_args = {'arg1': 'value1'} - - rms_norm = MockRMSNorm(hidden_size, **extra_args) - init(rms_norm, hidden_size, **extra_args) - - self.assertTrue(hasattr(rms_norm, 'ignore_anti')) - self.assertTrue(rms_norm.ignore_anti) - - self.assertTrue(hasattr(rms_norm, 'bias')) - self.assertIsInstance(rms_norm.bias, torch.nn.Parameter) - self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size])) - self.assertFalse(rms_norm.bias.requires_grad) - - @patch('torch_npu._npu_quant_rms_norm') - def test_wrapper_rmsnorm_forward_oot_with_residual( - self, mock_npu_quant_rms_norm): - hidden_size = 128 - x = torch.randn(hidden_size) - residual = torch.randn(hidden_size) - expected_out = torch.randn(hidden_size) - - mock_npu_quant_rms_norm.return_value = (expected_out, residual) - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x, residual - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = False - - output, res = forward_oot(rms_norm, x, residual) - - mock_npu_quant_rms_norm.assert_called_once() - - args, kwargs = mock_npu_quant_rms_norm.call_args - self.assertTrue(torch.equal(args[1], rms_norm.weight)) - self.assertTrue(torch.equal(args[2], rms_norm.bias)) - self.assertEqual(args[3], rms_norm.input_scale) - self.assertEqual(args[4], rms_norm.input_offset) - self.assertEqual(args[5], rms_norm.variance_epsilon) - self.assertTrue(torch.equal(res, residual)) - - @patch('torch_npu._npu_quant_rms_norm') - def test_wrapper_rmsnorm_forward_oot_without_residual( - self, mock_npu_quant_rms_norm): - hidden_size = 128 - x = torch.randn(hidden_size) - expected_out = torch.randn(hidden_size) - - mock_npu_quant_rms_norm.return_value = expected_out - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = False - - output = forward_oot(rms_norm, x) - - mock_npu_quant_rms_norm.assert_called_once() - - args, kwargs = mock_npu_quant_rms_norm.call_args - self.assertTrue(torch.equal(args[0], x)) - self.assertTrue(torch.equal(args[1], rms_norm.weight)) - self.assertTrue(torch.equal(args[2], rms_norm.bias)) - self.assertEqual(args[3], rms_norm.input_scale) - self.assertEqual(args[4], rms_norm.input_offset) - self.assertEqual(args[5], rms_norm.variance_epsilon) - - self.assertTrue(torch.equal(output, expected_out)) - - def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self): - hidden_size = 128 - x = torch.randn(hidden_size) - residual = torch.randn(hidden_size) - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x, residual - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = True - - output, res = forward_oot(rms_norm, x, residual) - - self.assertTrue(torch.equal(output, x.add_(rms_norm.bias))) - self.assertTrue(torch.equal(res, residual)) - - def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self): - hidden_size = 128 - x = torch.randn(hidden_size) - - @wrapper_rmsnorm_forward_oot - def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None): - return x - - rms_norm = MockRMSNorm(hidden_size) - rms_norm.ignore_anti = True - - output = forward_oot(rms_norm, x) - - self.assertTrue(torch.equal(output, x.add_(rms_norm.bias))) diff --git a/vllm_ascend/ops/vocab_parallel_embedding.py b/vllm_ascend/ops/vocab_parallel_embedding.py index 7ad35dc..0a7d7ef 100644 --- a/vllm_ascend/ops/vocab_parallel_embedding.py +++ b/vllm_ascend/ops/vocab_parallel_embedding.py @@ -97,6 +97,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding): if params_dtype is None: params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype # Divide the weight matrix along the vocaburaly dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size self.num_embeddings_per_partition = divide(self.num_embeddings_padded, diff --git a/vllm_ascend/quantization/func_wrapper.py b/vllm_ascend/quantization/func_wrapper.py deleted file mode 100644 index 8357695..0000000 --- a/vllm_ascend/quantization/func_wrapper.py +++ /dev/null @@ -1,184 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from typing import Optional, Tuple, Union - -import torch -import torch_npu -from vllm.logger import logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import UnquantizedLinearMethod -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig) - - -# func refers to vocabParallelEmbedding.__init__ -def wrapper_vocab_parallel_embedding_init(func): - - def init( - self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - func( - self, - num_embeddings, - embedding_dim, - params_dtype, - org_num_embeddings, - padding_size, - quant_config, - prefix, - ) - # TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class. - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - return init - - -# func refers to RMSNorm.__init__ -def wrapper_rmsnorm_init(func): - - def init(self, hidden_size: int, **extra_args) -> None: - func(self, hidden_size, **extra_args) - self.ignore_anti = True - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - - return init - - -# func refers to RMSNorm.forward_oot -def wrapper_rmsnorm_forward_oot(func): - - def _rmsnorm_forward_oot( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if not self.ignore_anti: - if residual is not None: - residual += x - out = torch_npu._npu_quant_rms_norm( - residual, - self.weight, - self.bias, - self.input_scale, - self.input_offset, - self.variance_epsilon, - ) - return out, residual - out = torch_npu._npu_quant_rms_norm( - x, - self.weight, - self.bias, - self.input_scale, - self.input_offset, - self.variance_epsilon, - ) - return out - - if residual is not None: - x, residual = func(self, x, residual) - return x.add_(self.bias), residual - - return func(self, x).add_(self.bias) - - return _rmsnorm_forward_oot - - -MODEL_LAYER_MAPPING = { - "LlamaModel": { - "attn": { - "layer_attr": "self_attn", - "proj_attr": "qkv_proj", - "norm_attr": "input_layernorm", - "unquantized_type": UnquantizedLinearMethod, - }, - "mlp": { - "layer_attr": "mlp", - "proj_attr": "gate_up_proj", - "norm_attr": "post_attention_layernorm", - "unquantized_type": UnquantizedLinearMethod, - }, - }, -} - - -def wrapper_load_model(func): - - def postprocess_loading(self) -> None: - func(self) - - def process_layer(layer, idx, mapping): - - def process_module(module_cfg, layer_obj): - if module_cfg is None: - return - - module_obj = getattr(layer_obj, module_cfg["layer_attr"], None) - if module_obj is None: - return - - proj_attr = module_cfg["proj_attr"] - if callable(proj_attr): - proj = proj_attr(module_obj, idx) - else: - proj = getattr(module_obj, proj_attr, None) - - norm = getattr(layer_obj, module_cfg["norm_attr"], None) - - if proj is None or norm is None: - return - - norm.ignore_anti = isinstance(proj.quant_method, - module_cfg["unquantized_type"]) - if not norm.ignore_anti: - for param_name in ["input_scale", "input_offset"]: - if hasattr(proj, param_name): - param = getattr(proj, param_name) - norm.register_parameter( - param_name, - torch.nn.Parameter(param.clone(), - requires_grad=False)) - - process_module(mapping.get("attn"), layer) - process_module(mapping.get("mlp"), layer) - - model_type = self.model.model.__class__.__name__ - mapping = MODEL_LAYER_MAPPING.get(model_type) - - if not mapping: - logger.info( - f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping." - ) - return - - for idx, layer in enumerate(self.model.model.layers): - process_layer(layer, idx, mapping) - - if isinstance(self.model.model.norm, RMSNorm): - self.model.model.norm.ignore_anti = True - - return postprocess_loading diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 6783f12..f4cd0d0 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -1,12 +1,7 @@ -import importlib -import sys -import types from typing import Any, Dict, Optional, Type from vllm.logger import logger -from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init, - wrapper_vocab_parallel_embedding_init) from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod) from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, @@ -64,7 +59,7 @@ def get_quant_method(quant_description: Dict[str, Any], prefix: str, layer_type: str, packed_modules_mapping: Optional[Dict[str, Any]] = None): - apply_quantization_patch(quant_description) + logger.info_once("Using the vLLM Ascend Quantization now!") if packed_modules_mapping is None: packed_modules_mapping = dict() # Attention @@ -88,135 +83,3 @@ def get_quant_method(quant_description: Dict[str, Any], ) raise NotImplementedError("Currently, vLLM Ascend only supports following quant types:" \ f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}") - - -def apply_quantization_patch(quant_description): - global patched - if patched: - return - for name in quant_description.keys(): - if "norm.bias" in name: - apply_patch("vllm.model_executor.layers.layernorm.RMSNorm", - "__init__", [wrapper_rmsnorm_init]) - apply_patch("vllm_ascend.ops.layernorm.AscendRMSNorm", - "forward_oot", [wrapper_rmsnorm_forward_oot]) - apply_patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding", - "__init__", [wrapper_vocab_parallel_embedding_init]) - break - patched = True - logger.info("Using the vLLM Ascend Quantization now!") - - -def apply_patch(target_module, target_function, wrappers): - - original_module, original_function = parse_path(target_module, - target_function, False) - - original_function_id = id(original_function) - - candidate = original_function - for wrapper in wrappers: - candidate = wrapper(candidate) - if target_function is not None: - setattr(original_module, target_function, candidate) - - for _, value in sys.modules.copy().items(): - if target_function is None: - continue - try: - attr = getattr(value, target_function, None) - if attr is not None and id(attr) == original_function_id: - setattr(value, target_function, candidate) - except ImportError: - continue - - -def parse_path(module_path, function_name, create_dummy): - """ - Parse module path and resolve/create modules as needed. - - Args: - module_path: Dot-separated module path - function_name: Target function name (None for module only) - create_dummy: Create dummy modules/functions when missing - - Returns: - Tuple of (resolved module, target function/none) - - Raises: - ModuleNotFoundError: If module path is invalid and create_dummy=False - AttributeError: If function is missing and create_dummy=False - """ - from importlib.machinery import ModuleSpec - - def create_dummy_module(full_path, parent=None): - """Create and register a placeholder module""" - dummy = types.ModuleType(full_path) - dummy.__file__ = "vllm_ascend.dummy_module.py" - dummy.__spec__ = ModuleSpec(full_path, None) - sys.modules[full_path] = dummy - if parent: - setattr(parent, full_path.split(".")[-1], dummy) - return dummy - - def create_placeholder_function(func_name): - """Create dummy function that raises when called""" - - def placeholder(*args, **kwargs): - raise NotImplementedError(f"Function {func_name} is a placeholder") - - placeholder.__name__ = func_name - return placeholder - - modules = module_path.split(".") - current_module = None - processed_path = [] - - for idx, part in enumerate(modules): - current_path = ".".join(modules[:idx + 1]) - parent_path = ".".join(modules[:idx]) if idx > 0 else None - - try: - current_module = importlib.import_module(current_path) - except ModuleNotFoundError: - # Handle missing module - parent = importlib.import_module( - parent_path) if parent_path else None - if parent and hasattr(parent, part): - # Use existing attribute from parent - current_module = getattr(parent, part) - # Check for early function resolution - if function_name and hasattr(current_module, function_name): - return current_module, getattr(current_module, - function_name) - if function_name and create_dummy: - ph_func = create_placeholder_function(function_name) - setattr(current_module, function_name, ph_func) - return current_module, ph_func - if function_name: - raise AttributeError( - f"Function {function_name} missing in {current_path}") - else: - if not create_dummy: - raise - # Create and register dummy module - current_module = create_dummy_module( - current_path, - parent=importlib.import_module(parent_path) - if parent_path else None) - - processed_path.append(part) - - # Final function handling - final_module = sys.modules[module_path] - if function_name is not None: - if not hasattr(final_module, function_name): - if create_dummy: - ph_func = create_placeholder_function(function_name) - setattr(final_module, function_name, ph_func) - else: - setattr(final_module, function_name, None) - return final_module, getattr(final_module, function_name) - - return final_module, None