### What this PR does / why we need it?

Add a `VLLMAscendQuantizer` to support w8a8 static (W8A8) and dynamic quantization on linear and MoE layers (W8A8_DYNAMIC). The quantizer is enabled if a model has a [quantize field](https://huggingface.co/vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8/blob/main/config.json#L27) in its config. If MindIE Turbo is installed, the MindIE Turbo quantizer is applied; otherwise the `VLLMAscendQuantizer` is used directly.

- This patch fixes the installation docs so installation works
- This patch enables norm quantization by patching `RMSNorm.__init__`, `RMSNorm.forward_oot`, and `NPUModelRunnerBase.load_model`
- Add `AscendW8A8LinearMethod` for W8A8
- Add `AscendW8A8DynamicLinearMethod` and `AscendW8A8DynamicFusedMoEMethod` for W8A8_DYNAMIC
- Add an e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8`

### Does this PR introduce _any_ user-facing change?

Yes, w8a8 quantization is now supported. With this patch, users can run w8a8 models with a command like:

```
vllm serve /root/.cache/modelscope/hub/Qwen/Qwen2.5-7B-Instruct-w8a8 --served-model-name "qwen2.5-7B"
```

### How was this patch tested?

0. CI passed: added an e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8`
1. From @Yikun: I tested Qwen2.5-0.5B-Instruct-w8a8 functionally and all is well, please refer to https://github.com/vllm-project/vllm-ascend/pull/580#issuecomment-2816747613
2. From @dingdingchaomian: Tested with qwen2.5-72b-instruct and deepseek-v2-lite-chat; both models were quantized using Ascend's msmodelslim tool:
   - Qwen2.5-72b-instruct was tested twice, once with static w8a8 and once with dynamic w8a8.
   - Deepseek-v2-lite-chat was tested once because its quantization uses both static and dynamic w8a8.

   Models were tested with both offline inference and online serving, and both work well. The inference code is exactly the same as the examples in https://vllm-ascend.readthedocs.io/en/latest/quick_start.html, with only the model path and tensor parallel size changed (a minimal sketch follows this description).

---------

Signed-off-by: dingdingchaomian <wangce21@huawei.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: dingdingchaomian <wangce21@huawei.com>
Co-authored-by: Angazenn <zengyanjia@huawei.com>
Co-authored-by: liujiaxu <liujiaxu4@huawei.com>
Co-authored-by: ApsarasX <apsarax@outlook.com>
Co-authored-by: ganyi1996ppo <pleaplusone.gy@gmail.com>
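For reference, a minimal offline-inference sketch in the spirit of the quick-start example (the model path and sampling settings are illustrative; any w8a8-quantized checkpoint such as `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8` should work):

```python
from vllm import LLM, SamplingParams

# Illustrative w8a8-quantized checkpoint; replace with your own model path
# and adjust tensor_parallel_size to your NPU setup.
llm = LLM(model="vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8",
          tensor_parallel_size=1)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```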
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import importlib
import sys
import types
from typing import Any, Dict, List, Optional

from vllm.logger import logger

from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
                           wrapper_rmsnorm_init)
from .w8a8 import AscendW8A8LinearMethod
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                           AscendW8A8DynamicLinearMethod)

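# Quantizer types that are handled by an external, customized implementation
# and therefore skipped by AscendQuantizer.get_quantizer below.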
CUSTOMIZED_QUANTIZER_TYPE: List[str] = []


class AscendQuantizer:
    """An interface to different quantization implementations for Ascend hardware."""

    @classmethod
    def get_quantizer(cls,
                      quant_config: Dict[str, Any],
                      prefix: str,
                      packed_modules_mapping: Optional[Dict[str,
                                                            Any]] = dict()):
        # TODO: Need a param to choose quantization algorithms.
        quantization_algorithm = ''

        if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE:
            return

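        # Prefer MindIE Turbo's quantizer when the package is installed;
        # otherwise fall back to the in-tree VLLMAscendQuantizer.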
        try:
            module = importlib.import_module("mindie_turbo")
            MindIETurboQuantizer = module.MindIETurboQuantizer
            return MindIETurboQuantizer.get_quantizer(quant_config, prefix,
                                                      packed_modules_mapping)
        except ImportError:
            return VLLMAscendQuantizer.get_quantizer(quant_config, prefix,
                                                     packed_modules_mapping)

    def build_linear_method(self):
        raise NotImplementedError

    def build_moe_method(self):
        raise NotImplementedError

    def build_attention_method(self):
        raise NotImplementedError


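# In-tree quantizer. On first construction it patches RMSNorm and model
# loading when the checkpoint quantizes norm layers, and it serves as the base
# for the per-quant-type quantizers in SUPPORT_ASCEND_QUANTIZER_TYPE.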
class VLLMAscendQuantizer:
    _instance: Optional[object] = None
    patched = False

    def __init__(self, quant_description):
        if VLLMAscendQuantizer.patched:
            return
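        # A quantized norm layer shows up in the weight description as a
        # "*norm.bias" entry; in that case patch RMSNorm and model loading so
        # the quantized norm kernels are picked up.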
        for name in quant_description.keys():
            if "norm.bias" in name:
                VLLMAscendQuantizer.apply_patch(
                    "vllm.model_executor.layers.layernorm.RMSNorm", "__init__",
                    [wrapper_rmsnorm_init])
                VLLMAscendQuantizer.apply_patch(
                    "vllm.model_executor.layers.layernorm.RMSNorm",
                    "forward_oot", [wrapper_rmsnorm_forward_oot])
                VLLMAscendQuantizer.apply_patch(
                    "vllm_ascend.worker.model_runner.NPUModelRunnerBase",
                    "load_model", [wrapper_load_model])
                break
        VLLMAscendQuantizer.patched = True
        logger.info("Using the vLLM Ascend Quantizer version now!")

    @staticmethod
    def apply_patch(target_module, target_function, wrappers):

        original_module, original_function = VLLMAscendQuantizer.parse_path(
            target_module, target_function, False)

        original_function_id = id(original_function)

        candidate = original_function
        for wrapper in wrappers:
            candidate = wrapper(candidate)
        if target_function is not None:
            setattr(original_module, target_function, candidate)

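        # Modules that imported the function directly ("from m import f") hold
        # their own reference; rebind those too, matching by object identity.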
        for key, value in sys.modules.copy().items():
            if (target_function is not None
                    and hasattr(value, target_function)
                    and id(getattr(value,
                                   target_function)) == original_function_id):
                setattr(value, target_function, candidate)

    @staticmethod
    def parse_path(module_path, function_name, create_dummy):
        """
        Parse module path and resolve/create modules as needed.

        Args:
            module_path: Dot-separated module path
            function_name: Target function name (None for module only)
            create_dummy: Create dummy modules/functions when missing

        Returns:
            Tuple of (resolved module, target function or None)

        Raises:
            ModuleNotFoundError: If module path is invalid and create_dummy=False
            AttributeError: If function is missing and create_dummy=False
        """
        from importlib.machinery import ModuleSpec

        def create_dummy_module(full_path, parent=None):
            """Create and register a placeholder module"""
            dummy = types.ModuleType(full_path)
            dummy.__file__ = "vllm_ascend.dummy_module.py"
            dummy.__spec__ = ModuleSpec(full_path, None)
            sys.modules[full_path] = dummy
            if parent:
                setattr(parent, full_path.split(".")[-1], dummy)
            return dummy

        def create_placeholder_function(func_name):
            """Create dummy function that raises when called"""

            def placeholder(*args, **kwargs):
                raise NotImplementedError(
                    f"Function {func_name} is a placeholder")

            placeholder.__name__ = func_name
            return placeholder

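        # Import the dotted path one prefix at a time; when a prefix fails to
        # import, fall back to an attribute on its parent or (optionally) a
        # dummy module so patching can still proceed.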
        modules = module_path.split(".")
        current_module = None
        processed_path = []

        for idx, part in enumerate(modules):
            current_path = ".".join(modules[:idx + 1])
            parent_path = ".".join(modules[:idx]) if idx > 0 else None

            try:
                current_module = importlib.import_module(current_path)
            except ModuleNotFoundError:
                # Handle missing module
                parent = importlib.import_module(
                    parent_path) if parent_path else None
                if parent and hasattr(parent, part):
                    # Use existing attribute from parent
                    current_module = getattr(parent, part)
                    # Check for early function resolution
                    if function_name and hasattr(current_module,
                                                 function_name):
                        return current_module, getattr(current_module,
                                                       function_name)
                    if function_name and create_dummy:
                        ph_func = create_placeholder_function(function_name)
                        setattr(current_module, function_name, ph_func)
                        return current_module, ph_func
                    if function_name:
                        raise AttributeError(
                            f"Function {function_name} missing in {current_path}"
                        )
                else:
                    if not create_dummy:
                        raise
                    # Create and register dummy module
                    current_module = create_dummy_module(
                        current_path,
                        parent=importlib.import_module(parent_path)
                        if parent_path else None)

            processed_path.append(part)

        # Final function handling
        final_module = sys.modules[module_path]
        if function_name is not None:
            if not hasattr(final_module, function_name):
                if create_dummy:
                    ph_func = create_placeholder_function(function_name)
                    setattr(final_module, function_name, ph_func)
                else:
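                    # NOTE: unlike the early-resolution path above, a function
                    # missing on the final module is bound to None here rather
                    # than raising.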
                    setattr(final_module, function_name, None)
            return final_module, getattr(final_module, function_name)

        return final_module, None

    @staticmethod
    def build_linear_method():
        raise NotImplementedError(
            "Linear method is not implemented for the current quant type.")

    @staticmethod
    def build_moe_method():
        raise NotImplementedError(
            "MoE method is not implemented for the current quant type.")

    @staticmethod
    def build_attention_method():
        raise NotImplementedError(
            "Attention method is not implemented for the current quant type.")

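    # Packed layers (e.g. qkv_proj fused from q_proj/k_proj/v_proj) are loaded
    # into one parameter, so every constituent shard must share one quant type.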
    @staticmethod
    def get_linear_quant_type(quant_description: Dict[str, Any], prefix: str,
                              packed_modules_mapping: Dict[str, Any]):
        proj_name = prefix.split(".")[-1]
        if proj_name in packed_modules_mapping:
            quant_type = None
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in packed_modules_mapping[proj_name]
            ]
            for shard_prefix in shard_prefixes:
                shard_quant_type = quant_description[shard_prefix + '.weight']

                if quant_type is None:
                    quant_type = shard_quant_type
                elif shard_quant_type != quant_type:
                    raise ValueError(
                        f"Not all shards of {prefix} are quantized with the "
                        f"same quant type. Shard {shard_prefix} uses "
                        f"{shard_quant_type}, but another shard uses "
                        f"{quant_type}. Please check the quantization config.")
        else:
            quant_type = quant_description[prefix + '.weight']
        return quant_type

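    # Resolve the quant type for this module from the checkpoint description
    # and return the singleton quantizer class registered for it.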
    @classmethod
    def get_quantizer(cls,
                      quant_description: Dict[str, Any],
                      prefix: str,
                      packed_modules_mapping: Optional[Dict[str, Any]] = None):
        if packed_modules_mapping is None:
            packed_modules_mapping = dict()
        # Attention
        if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
            quant_type = quant_description['fa_quant_type']
        # Linear
        else:
            quant_type = cls.get_linear_quant_type(quant_description, prefix,
                                                   packed_modules_mapping)
        if quant_type in SUPPORT_ASCEND_QUANTIZER_TYPE.keys():
            cls = SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type]
            if not cls._instance:
                cls._instance = cls(quant_description)
            return cls._instance
        raise NotImplementedError(
            "Currently, vLLM Ascend only supports the following quant types: "
            f"{list(SUPPORT_ASCEND_QUANTIZER_TYPE.keys())}")


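# Static W8A8: weights and activations are quantized to int8 with scales
# precomputed at quantization time.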
class W8A8Quantizer(VLLMAscendQuantizer):

    @staticmethod
    def build_linear_method():
        return AscendW8A8LinearMethod()


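# Dynamic W8A8: int8 weights with activation scales computed at runtime; also
# provides the dynamically quantized fused-MoE method.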
class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):

    @staticmethod
    def build_linear_method():
        return AscendW8A8DynamicLinearMethod()

    @staticmethod
    def build_moe_method():
        return AscendW8A8DynamicFusedMoEMethod()


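# Registry mapping the quant type recorded in a model's quantization
# description to its quantizer implementation.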
SUPPORT_ASCEND_QUANTIZER_TYPE = {
    "W8A8": W8A8Quantizer,
    "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
}