[Feature] Modify description and api for ascend quantization (#243)

### What this PR does / why we need it?
1. It adds more description for classes in quant_config.py
2. It renames AscendQKVQuantAttentionMethod to AscendKVCacheMethod to
align with vLLM naming style.
3. It modifies the process when AscendLinearMethod or
AscendKVCacheMethod calls create_weights.


### Does this PR introduce _any_ user-facing change?
Yes. When creating weights, now AscendLinearMethod uses get_weight,
get_pertensor_param and get_perchannel_param api from linear quant
implementation, while AscendKVCacheMethod passes layer into linear quant
implementation.

### How was this patch tested?
By performing offline inference

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-03-06 15:17:25 +08:00
committed by GitHub
parent cff08f9df8
commit 3217f0d10f
2 changed files with 52 additions and 72 deletions

View File

@@ -30,9 +30,9 @@ from vllm.model_executor.layers.quantization import \
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ChannelQuantScaleParameter, ModelWeightParameter,
ModelWeightParameter) PerTensorScaleParameter)
from .quantizer import AscendQuantizer from .quantizer import AscendQuantizer
@@ -41,7 +41,11 @@ logger = init_logger(__name__)
@register_quantization_config("ascend") @register_quantization_config("ascend")
class AscendQuantConfig(QuantizationConfig): class AscendQuantConfig(QuantizationConfig):
"""Config class for Ascend""" """Config class for Ascend
This class is a general class that parse quantization configs
that are supported on ascend hardware.
"""
def __init__(self, quant_config: Dict[str, Any]): def __init__(self, quant_config: Dict[str, Any]):
self.quant_description = quant_config self.quant_description = quant_config
@@ -84,10 +88,10 @@ class AscendQuantConfig(QuantizationConfig):
if self.is_layer_skipped_ascend(prefix, if self.is_layer_skipped_ascend(prefix,
self.packed_modules_mapping): self.packed_modules_mapping):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return AscendLinearMethod(self) return AscendLinearMethod(self, prefix)
if isinstance(layer, Attention) and \ if isinstance(layer, Attention) and \
'fa_quant_type' in self.quant_description.keys(): 'fa_quant_type' in self.quant_description.keys():
return AscendQKVQuantAttentionMethod(self) return AscendKVCacheMethod(self, prefix)
return None return None
def is_layer_skipped_ascend( def is_layer_skipped_ascend(
@@ -127,13 +131,16 @@ class AscendQuantConfig(QuantizationConfig):
class AscendLinearMethod(LinearMethodBase): class AscendLinearMethod(LinearMethodBase):
"""Linear method for Ascend quantization. """Linear method for Ascend quantization.
This class calls AscendQuantizer to search a specific quantization
implementations supported on ascend hardware for linear methods.
Args: Args:
quant_config: The Ascend quantization config. quant_config: The Ascend quantization config.
""" """
def __init__(self, quant_config: AscendQuantConfig) -> None: def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
self.quantizer = AscendQuantizer.get_quantizer( self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description) quant_config.quant_description, prefix)
self.quant_method = self.quantizer.build_linear_method() self.quant_method = self.quantizer.build_linear_method()
def create_weights( def create_weights(
@@ -146,57 +153,40 @@ class AscendLinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
weights = self.quant_method.create_weights(input_size_per_partition, weight_dict = self.quant_method.get_weight(input_size_per_partition,
output_size_per_partition, output_size_per_partition,
params_dtype) params_dtype)
for weight_name, weight_param in weight_dict.items():
weight_name = self.quant_method.get_weight()
if weight_name in weights.keys():
layer.register_parameter( layer.register_parameter(
weight_name, weight_name,
ModelWeightParameter(data=weights[weight_name].transpose(0, 1), ModelWeightParameter(data=weight_param,
input_dim=1, input_dim=1,
output_dim=0, output_dim=0,
weight_loader=weight_loader)) weight_loader=weight_loader))
else:
raise ValueError(
f"{weight_name} is nor registered. Please check your linear quant method implementation."
)
pertensor_names = self.quant_method.get_pertensor_param() pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name in pertensor_names: for pertensor_name, pertensor_param in pertensor_dict.items():
if pertensor_name in weights.keys(): param = PerTensorScaleParameter(data=pertensor_param,
param = BasevLLMParameter(data=weights[pertensor_name], weight_loader=weight_loader)
weight_loader=weight_loader) # disable warning
# disable warning param.ignore_warning = True
param.ignore_warning = True layer.register_parameter(pertensor_name, param)
layer.register_parameter(pertensor_name, param)
else:
raise ValueError(
f"{pertensor_name} is nor registered. Please check your linear quant method implementation."
)
perchannel_names = self.quant_method.get_perchannel_param() perchannel_dict = self.quant_method.get_perchannel_param(
for perchannel_name in perchannel_names: output_size_per_partition, params_dtype)
if perchannel_name in weights.keys(): for perchannel_name, perchannel_param in perchannel_dict.items():
layer.register_parameter( layer.register_parameter(
perchannel_name, perchannel_name,
ChannelQuantScaleParameter(data=weights[perchannel_name], ChannelQuantScaleParameter(data=perchannel_param,
output_dim=0, output_dim=0,
weight_loader=weight_loader)) weight_loader=weight_loader))
else:
raise ValueError(
f"{perchannel_name} is nor registered. Please check your linear quant method implementation."
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, if hasattr(self.quant_method, "process_weights_after_loading"):
'transpose_weight') and self.quant_method.transpose_weight: self.quant_method.process_weights_after_loading(layer)
layer.weight.data = layer.weight.data.transpose(1, 0)
def apply( def apply(
self, self,
@@ -210,47 +200,37 @@ class AscendLinearMethod(LinearMethodBase):
return self.quant_method.apply(layer, x, bias) return self.quant_method.apply(layer, x, bias)
class AscendQKVQuantAttentionMethod(BaseKVCacheMethod): class AscendKVCacheMethod(BaseKVCacheMethod):
"""Linear method for Ascend quantization. """KVCache method for Ascend quantization.
This class calls AscendQuantizer to search a specific quantization
implementations supported on ascend hardware for kvcache methods.
Args: Args:
quant_config: The Ascend quantization config. quant_config: The Ascend quantization config.
""" """
def __init__(self, quant_config: AscendQuantConfig) -> None: def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
self.quantizer = AscendQuantizer.get_quantizer( self.quantizer = AscendQuantizer.get_quantizer(
quant_config.quant_description) quant_config.quant_description, prefix)
self.quant_method = self.quantizer.build_attention_method() self.quant_method = self.quantizer.build_attention_method()
def create_weights(self, layer: torch.nn.Module) -> None: def create_weights(self, layer: torch.nn.Module) -> None:
# ascend attention quantization might include some extra weights # Different from linear method, there are no weight processing/slicing
# and must be loaded by dummy modules # steps for attention in vllm. So the whole process of create weights
extra_module_names = self.quant_method.get_extra_module_names() # is hidden into the specific quant method.
for name in extra_module_names: self.quant_method.create_weights(layer)
setattr(layer, name, torch.nn.Module())
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
weights = self.quant_method.create_weights(dtype, layer.num_heads,
layer.num_kv_heads)
for name, weight in weights.items():
module_name, weight_name = name.split('.')
module = getattr(layer, module_name)
module.register_parameter(
weight_name, torch.nn.Parameter(weight, requires_grad=False))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"): if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer) self.quant_method.process_weights_after_loading(layer)
def apply(self, layer: torch.nn.Module, query: torch.Tensor, def apply(self, layer: torch.nn.Module, query: torch.Tensor,
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
value_cache: torch.Tensor, scale: torch.Tensor, kv_cache: List[torch.Tensor], scale: torch.Tensor,
seq_lens_tensor_cpu: int, block_tables: torch.Tensor, seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
isPrefill: bool, attn_metadata, output) -> torch.Tensor: isPrefill: bool, attn_metadata, output) -> torch.Tensor:
return self.quant_method.apply(layer, query, key, value, key_cache, return self.quant_method.apply(layer, query, key, value, kv_cache,
value_cache, scale, seq_lens_tensor_cpu, scale, seq_lens_tensor_cpu,
block_tables, isPrefill, attn_metadata, block_tables, isPrefill, attn_metadata,
output) output)

View File

@@ -25,7 +25,7 @@ class AscendQuantizer:
"""An interface to different quantization implementations for ascend hardwares.""" """An interface to different quantization implementations for ascend hardwares."""
@classmethod @classmethod
def get_quantizer(cls, quant_config: Dict[str, Any]): def get_quantizer(cls, quant_config: Dict[str, Any], prefix: str):
# TODO: Need a param to choose quantization algorithms. # TODO: Need a param to choose quantization algorithms.
quantization_algorithm = '' quantization_algorithm = ''
@@ -39,7 +39,7 @@ class AscendQuantizer:
raise NotImplementedError( raise NotImplementedError(
"There is no available ascend quantizer.") "There is no available ascend quantizer.")
return MindIETurboQuantizer.get_quantizer(quant_config) return MindIETurboQuantizer.get_quantizer(quant_config, prefix)
def build_linear_method(self): def build_linear_method(self):
raise NotImplementedError raise NotImplementedError