Files
xc-llm-ascend/vllm_ascend/quantization/func_wrapper.py
liu 688350a3bb [bugfixed] fix the bug when run the inference of quantized ds-w8a8-mtp (#2134)
When run the inference of ds-w8a8-mtp, it reported 'ParamllelLMhead has
no attribute 'params_dtype''.

1. add wrapper of vocab_parallel_embedding, fixed the bugs when running
deepseek-w8a8-mtp

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>

- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

---------

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
2025-08-04 15:16:42 +08:00

185 lines
5.9 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple, Union
import torch
import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
# func refers to vocabParallelEmbedding.__init__
def wrapper_vocab_parallel_embedding_init(func):
def init(
self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
func(
self,
num_embeddings,
embedding_dim,
params_dtype,
org_num_embeddings,
padding_size,
quant_config,
prefix,
)
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
return init
# func refers to RMSNorm.__init__
def wrapper_rmsnorm_init(func):
def init(self, hidden_size: int, **extra_args) -> None:
func(self, hidden_size, **extra_args)
self.ignore_anti = True
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
return init
# func refers to RMSNorm.forward_oot
def wrapper_rmsnorm_forward_oot(func):
def _rmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not self.ignore_anti:
if residual is not None:
residual += x
out = torch_npu._npu_quant_rms_norm(
residual,
self.weight,
self.bias,
self.input_scale,
self.input_offset,
self.variance_epsilon,
)
return out, residual
out = torch_npu._npu_quant_rms_norm(
x,
self.weight,
self.bias,
self.input_scale,
self.input_offset,
self.variance_epsilon,
)
return out
if residual is not None:
x, residual = func(self, x, residual)
return x.add_(self.bias), residual
return func(self, x).add_(self.bias)
return _rmsnorm_forward_oot
MODEL_LAYER_MAPPING = {
"LlamaModel": {
"attn": {
"layer_attr": "self_attn",
"proj_attr": "qkv_proj",
"norm_attr": "input_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
"mlp": {
"layer_attr": "mlp",
"proj_attr": "gate_up_proj",
"norm_attr": "post_attention_layernorm",
"unquantized_type": UnquantizedLinearMethod,
},
},
}
def wrapper_load_model(func):
def postprocess_loading(self) -> None:
func(self)
def process_layer(layer, idx, mapping):
def process_module(module_cfg, layer_obj):
if module_cfg is None:
return
module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
if module_obj is None:
return
proj_attr = module_cfg["proj_attr"]
if callable(proj_attr):
proj = proj_attr(module_obj, idx)
else:
proj = getattr(module_obj, proj_attr, None)
norm = getattr(layer_obj, module_cfg["norm_attr"], None)
if proj is None or norm is None:
return
norm.ignore_anti = isinstance(proj.quant_method,
module_cfg["unquantized_type"])
if not norm.ignore_anti:
for param_name in ["input_scale", "input_offset"]:
if hasattr(proj, param_name):
param = getattr(proj, param_name)
norm.register_parameter(
param_name,
torch.nn.Parameter(param.clone(),
requires_grad=False))
process_module(mapping.get("attn"), layer)
process_module(mapping.get("mlp"), layer)
model_type = self.model.model.__class__.__name__
mapping = MODEL_LAYER_MAPPING.get(model_type)
if not mapping:
logger.info(
f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
)
return
for idx, layer in enumerate(self.model.model.layers):
process_layer(layer, idx, mapping)
if isinstance(self.model.model.norm, RMSNorm):
self.model.model.norm.ignore_anti = True
return postprocess_loading