152 lines
4.8 KiB
Python
152 lines
4.8 KiB
Python
|
|
#
|
||
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||
|
|
# This file is a part of the vllm-ascend project.
|
||
|
|
#
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
#
|
||
|
|
|
||
|
|
from typing import Optional, Tuple, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch_npu
|
||
|
|
from vllm.logger import logger
|
||
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||
|
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||
|
|
|
||
|
|
|
||
|
|
# func refers to RMSNorm.__init__
|
||
|
|
def wrapper_rmsnorm_init(func):
|
||
|
|
|
||
|
|
def init(self, hidden_size: int, **extra_args) -> None:
|
||
|
|
func(self, hidden_size, **extra_args)
|
||
|
|
self.ignore_anti = True
|
||
|
|
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||
|
|
requires_grad=False)
|
||
|
|
|
||
|
|
return init
|
||
|
|
|
||
|
|
|
||
|
|
# func refers to RMSNorm.forward_oot
|
||
|
|
def wrapper_rmsnorm_forward_oot(func):
|
||
|
|
|
||
|
|
def _rmsnorm_forward_oot(
|
||
|
|
self,
|
||
|
|
x: torch.Tensor,
|
||
|
|
residual: Optional[torch.Tensor] = None,
|
||
|
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||
|
|
if not self.ignore_anti:
|
||
|
|
if residual is not None:
|
||
|
|
residual += x
|
||
|
|
out = torch_npu._npu_quant_rms_norm(
|
||
|
|
residual,
|
||
|
|
self.weight,
|
||
|
|
self.bias,
|
||
|
|
self.input_scale,
|
||
|
|
self.input_offset,
|
||
|
|
self.variance_epsilon,
|
||
|
|
)
|
||
|
|
return out, residual
|
||
|
|
out = torch_npu._npu_quant_rms_norm(
|
||
|
|
x,
|
||
|
|
self.weight,
|
||
|
|
self.bias,
|
||
|
|
self.input_scale,
|
||
|
|
self.input_offset,
|
||
|
|
self.variance_epsilon,
|
||
|
|
)
|
||
|
|
return out
|
||
|
|
|
||
|
|
if residual is not None:
|
||
|
|
x, residual = func(self, x, residual)
|
||
|
|
return x.add_(self.bias), residual
|
||
|
|
|
||
|
|
return func(self, x).add_(self.bias)
|
||
|
|
|
||
|
|
return _rmsnorm_forward_oot
|
||
|
|
|
||
|
|
|
||
|
|
MODEL_LAYER_MAPPING = {
|
||
|
|
"LlamaModel": {
|
||
|
|
"attn": {
|
||
|
|
"layer_attr": "self_attn",
|
||
|
|
"proj_attr": "qkv_proj",
|
||
|
|
"norm_attr": "input_layernorm",
|
||
|
|
"unquantized_type": UnquantizedLinearMethod,
|
||
|
|
},
|
||
|
|
"mlp": {
|
||
|
|
"layer_attr": "mlp",
|
||
|
|
"proj_attr": "gate_up_proj",
|
||
|
|
"norm_attr": "post_attention_layernorm",
|
||
|
|
"unquantized_type": UnquantizedLinearMethod,
|
||
|
|
},
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
def wrapper_load_model(func):
|
||
|
|
|
||
|
|
def postprocess_loading(self) -> None:
|
||
|
|
func(self)
|
||
|
|
|
||
|
|
def process_layer(layer, idx, mapping):
|
||
|
|
|
||
|
|
def process_module(module_cfg, layer_obj):
|
||
|
|
if module_cfg is None:
|
||
|
|
return
|
||
|
|
|
||
|
|
module_obj = getattr(layer_obj, module_cfg["layer_attr"], None)
|
||
|
|
if module_obj is None:
|
||
|
|
return
|
||
|
|
|
||
|
|
proj_attr = module_cfg["proj_attr"]
|
||
|
|
if callable(proj_attr):
|
||
|
|
proj = proj_attr(module_obj, idx)
|
||
|
|
else:
|
||
|
|
proj = getattr(module_obj, proj_attr, None)
|
||
|
|
|
||
|
|
norm = getattr(layer_obj, module_cfg["norm_attr"], None)
|
||
|
|
|
||
|
|
if proj is None or norm is None:
|
||
|
|
return
|
||
|
|
|
||
|
|
norm.ignore_anti = isinstance(proj.quant_method,
|
||
|
|
module_cfg["unquantized_type"])
|
||
|
|
if not norm.ignore_anti:
|
||
|
|
for param_name in ["input_scale", "input_offset"]:
|
||
|
|
if hasattr(proj, param_name):
|
||
|
|
param = getattr(proj, param_name)
|
||
|
|
norm.register_parameter(
|
||
|
|
param_name,
|
||
|
|
torch.nn.Parameter(param.clone(),
|
||
|
|
requires_grad=False))
|
||
|
|
|
||
|
|
process_module(mapping.get("attn"), layer)
|
||
|
|
process_module(mapping.get("mlp"), layer)
|
||
|
|
|
||
|
|
model_type = self.model.model.__class__.__name__
|
||
|
|
mapping = MODEL_LAYER_MAPPING.get(model_type)
|
||
|
|
|
||
|
|
if not mapping:
|
||
|
|
logger.info(
|
||
|
|
f"Warning: Model type '{model_type}' not found in MODEL_LAYER_MAPPING. Skipping layer mapping."
|
||
|
|
)
|
||
|
|
return
|
||
|
|
|
||
|
|
for idx, layer in enumerate(self.model.model.layers):
|
||
|
|
process_layer(layer, idx, mapping)
|
||
|
|
|
||
|
|
if isinstance(self.model.model.norm, RMSNorm):
|
||
|
|
self.model.model.norm.ignore_anti = True
|
||
|
|
|
||
|
|
return postprocess_loading
|