[2/N][Refactor][Quantization] clean quantization patch (#2785)

### What this PR does / why we need it?
The quantization patch is unused code, so this PR removes it.
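
For context, the removed code injected quantization support by monkey-patching vLLM classes at runtime: `apply_patch` chained decorator-style wrappers around methods such as `RMSNorm.__init__`. A minimal sketch of that pattern (the `RMSNorm` below is a simplified stand-in, not vLLM's class):

```python
# Minimal sketch of the removed wrapper/monkey-patch pattern.
import torch


class RMSNorm:
    # Simplified stand-in for vLLM's RMSNorm.
    def __init__(self, hidden_size: int) -> None:
        self.hidden_size = hidden_size


def wrapper_rmsnorm_init(func):
    # Wrap the original __init__ and inject extra quantization attributes,
    # mirroring the removed func_wrapper.py below.
    def init(self, hidden_size: int) -> None:
        func(self, hidden_size)
        self.ignore_anti = True
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)
    return init


# What apply_patch() effectively did for each patched symbol:
RMSNorm.__init__ = wrapper_rmsnorm_init(RMSNorm.__init__)

norm = RMSNorm(128)
assert norm.ignore_anti and norm.bias.shape == torch.Size([128])
```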

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tested by CI.

- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55

Signed-off-by: 22dimensions <waitingwind@foxmail.com>
Author: 22dimensions
Committed: 2025-09-08 17:31:53 +08:00 (via GitHub)
Parent: cd88f89267
Commit: d51694a77b
4 changed files with 2 additions and 456 deletions

@@ -1,134 +0,0 @@
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
class MockRMSNorm:
def __init__(self, hidden_size: int, **extra_args):
self.hidden_size = hidden_size
self.weight = torch.ones(hidden_size)
self.input_scale = 1.0
self.input_offset = 0.0
self.variance_epsilon = 1e-6
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
self.ignore_anti = extra_args.get('ignore_anti', True)
class TestFuncWrapper(TestBase):
def test_wrapper_rmsnorm_init(self):
@wrapper_rmsnorm_init
def init(self, hidden_size: int, **extra_args) -> None:
self.hidden_size = hidden_size
hidden_size = 128
extra_args = {'arg1': 'value1'}
rms_norm = MockRMSNorm(hidden_size, **extra_args)
init(rms_norm, hidden_size, **extra_args)
self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
self.assertTrue(rms_norm.ignore_anti)
self.assertTrue(hasattr(rms_norm, 'bias'))
self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
self.assertFalse(rms_norm.bias.requires_grad)
@patch('torch_npu._npu_quant_rms_norm')
def test_wrapper_rmsnorm_forward_oot_with_residual(
self, mock_npu_quant_rms_norm):
hidden_size = 128
x = torch.randn(hidden_size)
residual = torch.randn(hidden_size)
expected_out = torch.randn(hidden_size)
mock_npu_quant_rms_norm.return_value = (expected_out, residual)
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x, residual
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = False
output, res = forward_oot(rms_norm, x, residual)
mock_npu_quant_rms_norm.assert_called_once()
args, kwargs = mock_npu_quant_rms_norm.call_args
self.assertTrue(torch.equal(args[1], rms_norm.weight))
self.assertTrue(torch.equal(args[2], rms_norm.bias))
self.assertEqual(args[3], rms_norm.input_scale)
self.assertEqual(args[4], rms_norm.input_offset)
self.assertEqual(args[5], rms_norm.variance_epsilon)
self.assertTrue(torch.equal(res, residual))
@patch('torch_npu._npu_quant_rms_norm')
def test_wrapper_rmsnorm_forward_oot_without_residual(
self, mock_npu_quant_rms_norm):
hidden_size = 128
x = torch.randn(hidden_size)
expected_out = torch.randn(hidden_size)
mock_npu_quant_rms_norm.return_value = expected_out
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = False
output = forward_oot(rms_norm, x)
mock_npu_quant_rms_norm.assert_called_once()
args, kwargs = mock_npu_quant_rms_norm.call_args
self.assertTrue(torch.equal(args[0], x))
self.assertTrue(torch.equal(args[1], rms_norm.weight))
self.assertTrue(torch.equal(args[2], rms_norm.bias))
self.assertEqual(args[3], rms_norm.input_scale)
self.assertEqual(args[4], rms_norm.input_offset)
self.assertEqual(args[5], rms_norm.variance_epsilon)
self.assertTrue(torch.equal(output, expected_out))
def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
hidden_size = 128
x = torch.randn(hidden_size)
residual = torch.randn(hidden_size)
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x, residual
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = True
output, res = forward_oot(rms_norm, x, residual)
self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
self.assertTrue(torch.equal(res, residual))
def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
hidden_size = 128
x = torch.randn(hidden_size)
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = True
output = forward_oot(rms_norm, x)
self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))

@@ -97,6 +97,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Divide the weight matrix along the vocabulary dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,

@@ -1,184 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple, Union
import torch
import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
# func refers to VocabParallelEmbedding.__init__
def wrapper_vocab_parallel_embedding_init(func):
def init(
self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
func(
self,
num_embeddings,
embedding_dim,
params_dtype,
org_num_embeddings,
padding_size,
quant_config,
prefix,
)
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
return init
# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):
def init(self, hidden_size: int, **extra_args) -> None:
func(self, hidden_size, **extra_args)
self.ignore_anti = True
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
return init
# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):
def _rmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not self.ignore_anti:
if residual is not None:
residual += x
out = torch_npu._npu_quant_rms_norm(
residual,
self.weight,
self.bias,
self.input_scale,
self.input_offset,
self.variance_epsilon,
)
return out, residual
out = torch_npu._npu_quant_rms_norm(
x,
self.weight,
self.bias,
self.input_scale,
self.input_offset,
self.variance_epsilon,
)
return out
if residual is not None:
x, residual = func(self, x, residual)
return x.add_(self.bias), residual
return func(self, x).add_(self.bias)
return _rmsnorm_forward_oot
MODEL_LAYER_MAPPING = {
"LlamaModel": {
"attn": {
"layer_attr": "self_attn",
"proj_attr": "qkv_proj",
"norm_attr": "input_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
"mlp": {
"layer_attr": "mlp",
"proj_attr": "gate_up_proj",
"norm_attr": "post_attention_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
},
}
def wrapper_load_model(func):
def postprocess_loading(self) -> None:
func(self)
def process_layer(layer, idx, mapping):
def process_module(module_cfg, layer_obj):
if module_cfg is None:
return
module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
if module_obj is None:
return
proj_attr = module_cfg["proj_attr"]
if callable(proj_attr):
proj = proj_attr(module_obj, idx)
else:
proj = getattr(module_obj, proj_attr, None)
norm = getattr(layer_obj, module_cfg["norm_attr"], None)
if proj is None or norm is None:
return
norm.ignore_anti = isinstance(proj.quant_method,
module_cfg["unquantized_type"])
if not norm.ignore_anti:
for param_name in ["input_scale", "input_offset"]:
if hasattr(proj, param_name):
param = getattr(proj, param_name)
norm.register_parameter(
param_name,
torch.nn.Parameter(param.clone(),
requires_grad=False))
process_module(mapping.get("attn"), layer)
process_module(mapping.get("mlp"), layer)
model_type = self.model.model.__class__.__name__
mapping = MODEL_LAYER_MAPPING.get(model_type)
if not mapping:
logger.info(
f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
)
return
for idx, layer in enumerate(self.model.model.layers):
process_layer(layer, idx, mapping)
if isinstance(self.model.model.norm, RMSNorm):
self.model.model.norm.ignore_anti = True
return postprocess_loading

@@ -1,12 +1,7 @@
import importlib
import sys
import types
from typing import Any, Dict, Optional, Type
from vllm.logger import logger
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
wrapper_vocab_parallel_embedding_init)
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
@@ -64,7 +59,7 @@ def get_quant_method(quant_description: Dict[str, Any],
prefix: str,
layer_type: str,
packed_modules_mapping: Optional[Dict[str, Any]] = None):
apply_quantization_patch(quant_description)
logger.info_once("Using the vLLM Ascend Quantization now!")
if packed_modules_mapping is None:
packed_modules_mapping = dict()
# Attention
@@ -88,135 +83,3 @@ def get_quant_method(quant_description: Dict[str, Any],
)
raise NotImplementedError("Currently, vLLM Ascend only supports the following quant types: " \
f"{list(ASCEND_QUANTIZATION_METHOD_MAP.keys())}")
def apply_quantization_patch(quant_description):
global patched
if patched:
return
for name in quant_description.keys():
if "norm.bias" in name:
apply_patch("vllm.model_executor.layers.layernorm.RMSNorm",
"__init__", [wrapper_rmsnorm_init])
apply_patch("vllm_ascend.ops.layernorm.AscendRMSNorm",
"forward_oot", [wrapper_rmsnorm_forward_oot])
apply_patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding",
"__init__", [wrapper_vocab_parallel_embedding_init])
break
patched = True
logger.info("Using the vLLM Ascend Quantization now!")
def apply_patch(target_module, target_function, wrappers):
original_module, original_function = parse_path(target_module,
target_function, False)
original_function_id = id(original_function)
candidate = original_function
for wrapper in wrappers:
candidate = wrapper(candidate)
if target_function is not None:
setattr(original_module, target_function, candidate)
for _, value in sys.modules.copy().items():
if target_function is None:
continue
try:
attr = getattr(value, target_function, None)
if attr is not None and id(attr) == original_function_id:
setattr(value, target_function, candidate)
except ImportError:
continue
def parse_path(module_path, function_name, create_dummy):
"""
Parse module path and resolve/create modules as needed.
Args:
module_path: Dot-separated module path
function_name: Target function name (None for module only)
create_dummy: Create dummy modules/functions when missing
Returns:
Tuple of (resolved module, target function or None)
Raises:
ModuleNotFoundError: If module path is invalid and create_dummy=False
AttributeError: If function is missing and create_dummy=False
"""
from importlib.machinery import ModuleSpec
def create_dummy_module(full_path, parent=None):
"""Create and register a placeholder module"""
dummy = types.ModuleType(full_path)
dummy.__file__ = "vllm_ascend.dummy_module.py"
dummy.__spec__ = ModuleSpec(full_path, None)
sys.modules[full_path] = dummy
if parent:
setattr(parent, full_path.split(".")[-1], dummy)
return dummy
def create_placeholder_function(func_name):
"""Create dummy function that raises when called"""
def placeholder(*args, **kwargs):
raise NotImplementedError(f"Function {func_name} is a placeholder")
placeholder.__name__ = func_name
return placeholder
modules = module_path.split(".")
current_module = None
processed_path = []
for idx, part in enumerate(modules):
current_path = ".".join(modules[:idx + 1])
parent_path = ".".join(modules[:idx]) if idx > 0 else None
try:
current_module = importlib.import_module(current_path)
except ModuleNotFoundError:
# Handle missing module
parent = importlib.import_module(
parent_path) if parent_path else None
if parent and hasattr(parent, part):
# Use existing attribute from parent
current_module = getattr(parent, part)
# Check for early function resolution
if function_name and hasattr(current_module, function_name):
return current_module, getattr(current_module,
function_name)
if function_name and create_dummy:
ph_func = create_placeholder_function(function_name)
setattr(current_module, function_name, ph_func)
return current_module, ph_func
if function_name:
raise AttributeError(
f"Function {function_name} missing in {current_path}")
else:
if not create_dummy:
raise
# Create and register dummy module
current_module = create_dummy_module(
current_path,
parent=importlib.import_module(parent_path)
if parent_path else None)
processed_path.append(part)
# Final function handling
final_module = sys.modules[module_path]
if function_name is not None:
if not hasattr(final_module, function_name):
if create_dummy:
ph_func = create_placeholder_function(function_name)
setattr(final_module, function_name, ph_func)
else:
setattr(final_module, function_name, None)
return final_module, getattr(final_module, function_name)
return final_module, None
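
For reference, the removed `apply_patch` resolved its target through `parse_path` (which treats a trailing class name as a module attribute) and then rebound the wrapped function on every module that had imported the original symbol, matching by `id()`. A usage sketch, assuming the removed helpers shown above are still importable:

```python
# Usage sketch of the removed helpers, as they existed before this commit.
# parse_path() returns (owner, function); here the owner is the RMSNorm
# class, found as an attribute of the layernorm module.
owner, original = parse_path(
    "vllm.model_executor.layers.layernorm.RMSNorm", "__init__", False)

# apply_patch() chains the wrappers around the original function and
# installs the result, also fixing up stale re-exports across sys.modules.
apply_patch("vllm.model_executor.layers.layernorm.RMSNorm",
            "__init__", [wrapper_rmsnorm_init])
```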