Cherry pick from 0.7.1 to keep the main code newest Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
257 lines
10 KiB
Python
257 lines
10 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
from types import MappingProxyType
|
|
from typing import Any, Dict, List, Mapping, Optional
|
|
|
|
import torch
|
|
import torch_npu # noqa: F401
|
|
from vllm.distributed import get_tensor_model_parallel_rank
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
|
RowParallelLinear,
|
|
UnquantizedLinearMethod)
|
|
from vllm.model_executor.layers.quantization import \
|
|
register_quantization_config
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
QuantizationConfig, QuantizeMethodBase)
|
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
|
ChannelQuantScaleParameter,
|
|
ModelWeightParameter)
|
|
|
|
from .quantizer import AscendQuantizer
|
|
|
|
# Module-level logger created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)
|
|
|
|
|
|
@register_quantization_config("ascend")
class AscendQuantConfig(QuantizationConfig):
    """Quantization config registered under the name ``"ascend"``.

    The config wraps a plain dictionary (``quant_description``). Inside this
    class the dictionary is consulted in two ways:

    * ``"<layer prefix>.weight"`` entries equal to ``"FLOAT"`` mark linear
      layers that are left unquantized;
    * the presence of a ``"fa_quant_type"`` key enables the quantized
      attention method for ``Attention`` layers.
    """

    def __init__(self, quant_config: Dict[str, Any]):
        """Store the raw quantization description.

        Args:
            quant_config: Description dictionary; kept by reference as
                ``self.quant_description`` (no copy is made).
        """
        # Run the base-class initializer so that any state it defines exists.
        # ``get_quant_method`` reads ``self.packed_modules_mapping``, which
        # this class never assigns itself.
        super().__init__()
        self.quant_description = quant_config

    def __repr__(self) -> str:
        """Return the base-class repr prefixed with the class name."""
        return "AscendQuantConfig:\n" + super().__repr__()

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "ascend"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes accepted by this config."""
        return [torch.int8, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Not applicable on Ascend hardware.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Ascend hardware does not support \"get_min_capability\" feature.")

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return an empty list: no config file names are declared here."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
        """Build a config directly from an already-parsed dictionary."""
        return cls(config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Select ``"ascend"`` whenever an NPU device is available.

        Both arguments are ignored; the decision depends only on
        ``torch.npu.is_available()``.

        Returns:
            ``"ascend"`` if an NPU is available, otherwise ``None``.
        """
        if torch.npu.is_available():
            return "ascend"
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantization method for ``layer``.

        Args:
            layer: Module being constructed.
            prefix: Dotted name of the layer; used to look up its entry in
                ``quant_description``.

        Returns:
            * ``UnquantizedLinearMethod`` for a ``LinearBase`` layer that
              ``is_layer_skipped_ascend`` reports as skipped;
            * ``AscendLinearMethod`` for any other ``LinearBase`` layer;
            * ``AscendQKVQuantAttentionMethod`` for an ``Attention`` layer
              when ``"fa_quant_type"`` is present in ``quant_description``;
            * ``None`` otherwise.
        """
        # Imported here rather than at module level -- presumably to avoid an
        # import cycle; confirm before hoisting.
        from vllm.attention.layer import Attention

        if isinstance(layer, LinearBase):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
                return UnquantizedLinearMethod()
            return AscendLinearMethod(self)

        if isinstance(layer, Attention) and \
                'fa_quant_type' in self.quant_description:
            return AscendQKVQuantAttentionMethod(self)

        return None

    def is_layer_skipped_ascend(
        self,
        prefix: str,
        fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
    ) -> bool:
        """Tell whether the layer named ``prefix`` is left unquantized.

        A layer counts as skipped when its ``"<prefix>.weight"`` entry in
        ``quant_description`` equals ``"FLOAT"``. If the last dotted
        component of ``prefix`` is a key of ``fused_mapping``, every shard
        listed for it is checked instead and all shards must agree.

        Args:
            prefix: Dotted layer name.
            fused_mapping: Maps a fused projection name to the names of the
                shards it is composed of.

        Returns:
            ``True`` if the layer (or all of its shards) is marked
            ``"FLOAT"``, ``False`` otherwise.

        Raises:
            ValueError: If the shards of a fused layer disagree, or the
                fused mapping lists no shards for the layer.
            KeyError: If a required ``"<name>.weight"`` entry is missing
                from ``quant_description``.
        """
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        parent, sep, proj_name = prefix.rpartition(".")
        if proj_name not in fused_mapping:
            return self.quant_description[prefix + '.weight'] == "FLOAT"

        is_skipped = None
        for shard_proj_name in fused_mapping[proj_name]:
            # Swap only the final path component. ``str.replace`` would also
            # rewrite earlier components that merely contain ``proj_name``.
            shard_prefix = parent + sep + shard_proj_name
            is_shard_skipped = self.quant_description[shard_prefix +
                                                      '.weight'] == "FLOAT"

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision.")

        if is_skipped is None:
            # Explicit check instead of ``assert`` so it survives ``-O``.
            raise ValueError(
                f"Fused layer {prefix} has no shards listed in the fused "
                "mapping.")
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list: no scaled activation names are declared."""
        return []
|
|
|
|
|
|
class AscendLinearMethod(LinearMethodBase):
    """Linear method for Ascend quantization.

    The class is a thin adapter: the quantizer returned by
    ``AscendQuantizer.get_quantizer`` for the config's ``quant_description``
    builds the actual linear quant method, and every operation here is
    forwarded to it.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig) -> None:
        self.quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description)
        self.quant_method = self.quantizer.build_linear_method()

    @staticmethod
    def _require_weight(weights, name: str):
        """Return ``weights[name]`` or raise if the entry is missing.

        Raises:
            ValueError: If ``name`` is not a key of ``weights``.
        """
        if name not in weights:
            raise ValueError(
                f"{name} is not registered. Please check your linear quant method implementation."
            )
        return weights[name]

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Create the quantized parameters and register them on ``layer``.

        The underlying quant method supplies a mapping of tensors plus the
        names of its main weight, its per-tensor parameters and its
        per-channel parameters; each named tensor is wrapped in the matching
        vLLM parameter class and registered on ``layer``.

        Args:
            layer: Module that receives the parameters.
            input_size_per_partition: Forwarded to the quant method.
            output_partition_sizes: Summed to obtain the output size that is
                forwarded to the quant method.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Forwarded to the quant method.
            **extra_weight_attrs: Only ``"weight_loader"`` is read.

        Raises:
            ValueError: If the quant method names a parameter that is absent
                from the tensors it created.
        """
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weights = self.quant_method.create_weights(input_size_per_partition,
                                                   output_size_per_partition,
                                                   params_dtype)

        # Main weight: registered transposed, with input_dim=1 / output_dim=0.
        weight_name = self.quant_method.get_weight()
        weight_data = self._require_weight(weights, weight_name)
        layer.register_parameter(
            weight_name,
            ModelWeightParameter(data=weight_data.transpose(0, 1),
                                 input_dim=1,
                                 output_dim=0,
                                 weight_loader=weight_loader))

        # Per-tensor parameters.
        for pertensor_name in self.quant_method.get_pertensor_param():
            param = BasevLLMParameter(data=self._require_weight(
                weights, pertensor_name),
                                      weight_loader=weight_loader)
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)

        # Per-channel parameters, registered with output_dim=0.
        for perchannel_name in self.quant_method.get_perchannel_param():
            layer.register_parameter(
                perchannel_name,
                ChannelQuantScaleParameter(data=self._require_weight(
                    weights, perchannel_name),
                                           output_dim=0,
                                           weight_loader=weight_loader))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transpose ``layer.weight`` if the quant method asks for it.

        The transpose happens only when the quant method has a truthy
        ``transpose_weight`` attribute; otherwise the layer is untouched.
        """
        if getattr(self.quant_method, 'transpose_weight', False):
            layer.weight.data = layer.weight.data.transpose(1, 0)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the quantized linear operation through the quant method.

        For ``RowParallelLinear`` layers the current tensor-parallel rank is
        passed as an extra positional argument.

        Returns:
            Whatever the underlying quant method's ``apply`` returns.
        """
        if isinstance(layer, RowParallelLinear):
            tp_rank = get_tensor_model_parallel_rank()
            return self.quant_method.apply(layer, x, bias, tp_rank)
        return self.quant_method.apply(layer, x, bias)
|
|
|
|
|
|
class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
    """Attention quantization method for Ascend.

    Acts as an adapter around the attention method built by the quantizer
    that ``AscendQuantizer.get_quantizer`` returns for the config's
    ``quant_description``.

    Args:
        quant_config: The Ascend quantization config.
    """

    def __init__(self, quant_config: AscendQuantConfig) -> None:
        description = quant_config.quant_description
        self.quantizer = AscendQuantizer.get_quantizer(description)
        self.quant_method = self.quantizer.build_attention_method()

    def create_weights(self, layer: torch.nn.Module) -> None:
        """Attach the attention quantization parameters to ``layer``.

        Empty placeholder modules are attached first, one per name reported
        by the quant method. The tensors created by the quant method are then
        registered on them as non-trainable parameters. Each tensor key is
        split on ``"."`` into ``<module name>.<parameter name>``; a key that
        does not split into exactly two parts raises ``ValueError``.
        """
        # ascend attention quantization might include some extra weights
        # and must be loaded by dummy modules
        for placeholder_name in self.quant_method.get_extra_module_names():
            setattr(layer, placeholder_name, torch.nn.Module())

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        default_dtype = torch.get_default_dtype()
        created = self.quant_method.create_weights(default_dtype,
                                                   layer.num_heads,
                                                   layer.num_kv_heads)

        for qualified_name, tensor in created.items():
            owner_name, param_name = qualified_name.split('.')
            owner = getattr(layer, owner_name)
            frozen_param = torch.nn.Parameter(tensor, requires_grad=False)
            owner.register_parameter(param_name, frozen_param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Forward to the quant method's hook when it defines one."""
        if not hasattr(self.quant_method, "process_weights_after_loading"):
            return
        self.quant_method.process_weights_after_loading(layer)

    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor,
              kv_cache: List[torch.Tensor], scale: torch.Tensor,
              seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
              isPrefill: bool, attn_metadata, output) -> torch.Tensor:
        """Run quantized attention through the underlying quant method.

        All arguments are forwarded positionally and unchanged, in the order
        they are declared.

        Returns:
            Whatever the underlying quant method's ``apply`` returns.
        """
        forwarded = (layer, query, key, value, kv_cache, scale,
                     seq_lens_tensor_cpu, block_tables, isPrefill,
                     attn_metadata, output)
        return self.quant_method.apply(*forwarded)
|