Files
xc-llm-ascend/vllm_ascend/quantization/func_wrapper.py
Yikun Jiang 12cae04db9 [quantization] Support w8a8 quantization (#580)
### What this PR does / why we need it?

Add a `VLLMAscendQuantizer` to support static w8a8 quantization (W8A8) and
dynamic w8a8 quantization on linear and MoE layers (W8A8_DYNAMIC). The
quantizer is enabled if a model has a [quantize
field](https://huggingface.co/vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8/blob/main/config.json#L27).
If MindIE Turbo is installed, the MindIE Turbo Quantizer is applied;
otherwise `VLLMAscendQuantizer` is used directly.

- This patch fixes the installation docs to make installation work
- This patch enables norm quantization by patching `RMSNorm.__init__`,
`RMSNorm.forward_oot`, `NPUModelRunnerBase.load_model`
- Add `AscendW8A8LinearMethod` for W8A8
- Add `AscendW8A8DynamicLinearMethod` and
`AscendW8A8DynamicFusedMoEMethod` for W8A8_DYNAMIC
- Add an e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8`

### Does this PR introduce _any_ user-facing change?
Yes, it adds support for w8a8 quantization. With this patch, users can
use the command below to run w8a8 models:

```
vllm serve /root/.cache/modelscope/hub/Qwen/Qwen2.5-7B-Instruct-w8a8 --served-model-name "qwen2.5-7B"
```

### How was this patch tested?
0. CI passed: add e2e test for `vllm-ascend/Qwen2.5-0.5B-Instruct-w8a8`
1. From @Yikun:
I ran a functional test with Qwen2.5-0.5B-Instruct-w8a8 and all is well;
please refer to
https://github.com/vllm-project/vllm-ascend/pull/580#issuecomment-2816747613

2. From @dingdingchaomian :
Tested with the qwen2.5-72b-instruct and deepseek-v2-lite-chat models; both
models were quantized using Ascend's msmodelslim tool:
- Qwen2.5-72b-instruct was tested twice, once for w8a8 static and once
for w8a8 dynamic.
- Deepseek-v2-lite-chat was tested once because its quantization uses
both static and dynamic w8a8.

Models were tested using both offline inference and online serving, and
both work well. The inference code is exactly the same as the
examples in
https://vllm-ascend.readthedocs.io/en/latest/quick_start.html, with
the model path and tensor parallel size changed.

---------

Signed-off-by: dingdingchaomian <wangce21@huawei.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: dingdingchaomian <wangce21@huawei.com>
Co-authored-by: Angazenn <zengyanjia@huawei.com>
Co-authored-by: liujiaxu <liujiaxu4@huawei.com>
Co-authored-by: ApsarasX <apsarax@outlook.com>
Co-authored-by: ganyi1996ppo <pleaplusone.gy@gmail.com>
2025-04-20 18:14:05 +08:00

152 lines
4.8 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
from typing import Optional, Tuple, Union

import torch
import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):
    """Wrap an ``__init__`` so each instance also carries anti-outlier state.

    The returned initializer first runs ``func`` and then adds two attributes
    to the instance:

    * ``ignore_anti`` -- initialised to ``True``. ``wrapper_rmsnorm_forward_oot``
      reads this flag to choose between the original forward pass and the
      quantized kernel.
    * ``bias`` -- a non-trainable ``torch.nn.Parameter`` of zeros with shape
      ``(hidden_size,)``.

    Args:
        func: The original ``__init__`` being wrapped (``RMSNorm.__init__``).

    Returns:
        The replacement ``__init__``. It accepts the same arguments as
        ``func``; ``hidden_size`` must be the first argument.
    """

    @functools.wraps(func)
    def init(self, hidden_size: int, *args, **extra_args) -> None:
        # Forward positional arguments as well as keywords so call sites that
        # pass further arguments positionally keep working once patched.
        func(self, hidden_size, *args, **extra_args)
        self.ignore_anti = True
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                       requires_grad=False)

    return init
# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):
    """Wrap ``RMSNorm.forward_oot`` with an optional quantized-norm path.

    The behaviour of the returned forward depends on ``self.ignore_anti``:

    * ``True``: call the original ``func`` and add ``self.bias`` in place to
      its normalized output.
    * ``False``: call ``torch_npu._npu_quant_rms_norm`` with ``self.weight``,
      ``self.bias``, ``self.input_scale``, ``self.input_offset`` and
      ``self.variance_epsilon`` instead of ``func``.

    Args:
        func: The original ``forward_oot`` being wrapped.

    Returns:
        The replacement forward. It returns a single tensor when ``residual``
        is ``None`` and an ``(output, residual)`` tuple otherwise.
    """

    @functools.wraps(func)
    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if self.ignore_anti:
            # Original norm, then the bias added in place on its output.
            if residual is None:
                return func(self, x).add_(self.bias)
            x, residual = func(self, x, residual)
            return x.add_(self.bias), residual

        if residual is not None:
            # Deliberately in place: the caller's residual tensor receives the
            # sum, and that same tensor is normalized and handed back.
            residual += x
            norm_input = residual
        else:
            norm_input = x

        # NOTE(review): presumably returns the normalized activations already
        # quantized with input_scale/input_offset -- confirm against the
        # torch_npu documentation.
        out = torch_npu._npu_quant_rms_norm(
            norm_input,
            self.weight,
            self.bias,
            self.input_scale,
            self.input_offset,
            self.variance_epsilon,
        )
        if residual is None:
            return out
        return out, residual

    return _rmsnorm_forward_oot
# Describes, per model class, which norm is paired with which projection.
# `wrapper_load_model` looks up the top-level key using the class name of
# `self.model.model`, then reads the "attn" and "mlp" entries for every element
# of `self.model.model.layers`. Each entry holds:
#   layer_attr:       attribute on the layer that holds the sub-module
#   proj_attr:        attribute name of the projection on that sub-module
#                     (a callable taking (sub-module, layer index) is also
#                     accepted by the loader)
#   norm_attr:        attribute on the layer that holds the paired norm
#   unquantized_type: if the projection's `quant_method` is an instance of
#                     this type, the norm keeps `ignore_anti = True`
MODEL_LAYER_MAPPING = {
    "LlamaModel": {
        "attn": {
            "layer_attr": "self_attn",
            "proj_attr": "qkv_proj",
            "norm_attr": "input_layernorm",
            "unquantized_type": UnquantizedLinearMethod,
        },
        "mlp": {
            "layer_attr": "mlp",
            "proj_attr": "gate_up_proj",
            "norm_attr": "post_attention_layernorm",
            "unquantized_type": UnquantizedLinearMethod,
        },
    },
}
def _configure_norm_for_module(layer, idx, module_cfg) -> None:
    """Set up one norm of ``layer`` from the projection it is paired with.

    Args:
        layer: One element of ``self.model.model.layers``.
        idx: Index of ``layer``; only passed on when ``proj_attr`` is callable.
        module_cfg: An "attn"/"mlp" entry of ``MODEL_LAYER_MAPPING``, or
            ``None`` when the mapping has no such entry.
    """
    if module_cfg is None:
        return

    module_obj = getattr(layer, module_cfg["layer_attr"], None)
    if module_obj is None:
        return

    proj_attr = module_cfg["proj_attr"]
    if callable(proj_attr):
        proj = proj_attr(module_obj, idx)
    else:
        proj = getattr(module_obj, proj_attr, None)

    norm = getattr(layer, module_cfg["norm_attr"], None)
    if proj is None or norm is None:
        return

    # An unquantized projection means the norm keeps using the original path.
    norm.ignore_anti = isinstance(proj.quant_method,
                                  module_cfg["unquantized_type"])
    if norm.ignore_anti:
        return

    # Give the norm its own frozen copies of the projection's input
    # quantization parameters, which the quantized forward path reads.
    for param_name in ("input_scale", "input_offset"):
        if hasattr(proj, param_name):
            param = getattr(proj, param_name)
            norm.register_parameter(
                param_name,
                torch.nn.Parameter(param.clone(), requires_grad=False))


def wrapper_load_model(func):
    """Wrap ``load_model`` to configure norms once the weights are loaded.

    After ``func`` runs, the class name of ``self.model.model`` is looked up in
    ``MODEL_LAYER_MAPPING``. If it is found, every layer's "attn" and "mlp"
    norms are configured from their paired projections, and a final
    ``self.model.model.norm`` that is an ``RMSNorm`` is forced to
    ``ignore_anti = True``. If it is not found, a warning is logged and
    nothing else is changed.

    Args:
        func: The original ``load_model`` being wrapped.

    Returns:
        The replacement ``load_model``; it returns ``None``.
    """

    @functools.wraps(func)
    def postprocess_loading(self) -> None:
        func(self)

        model_type = self.model.model.__class__.__name__
        mapping = MODEL_LAYER_MAPPING.get(model_type)
        if not mapping:
            # Logged at warning level: every norm stays on the original path.
            logger.warning(
                "Model type '%s' not found in MODEL_LAYER_MAPPING. "
                "Skipping layer mapping.", model_type)
            return

        for idx, layer in enumerate(self.model.model.layers):
            for module_key in ("attn", "mlp"):
                _configure_norm_for_module(layer, idx,
                                           mapping.get(module_key))

        if isinstance(self.model.model.norm, RMSNorm):
            self.model.model.norm.ignore_anti = True

    return postprocess_loading