From 3217f0d10fbbc6e6cc8b0db9594b8cef515b4f90 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:17:25 +0800 Subject: [PATCH] [Feature] Modify description and api for ascend quantization (#243) ### What this PR does / why we need it? 1. It adds more description for classes in quant_config.py 2. It renames AscendQKVQuantAttentionMethod to AscendKVCacheMethod to align with vLLM naming style. 3. It modifies the process when AscendLinearMethod or AscendKVCacheMethod calls create_weights. ### Does this PR introduce _any_ user-facing change? Yes. When creating weights, now AscendLinearMethod uses get_weight, get_pertensor_param and get_perchannel_param api from linear quant implementation, while AscendKVCacheMethod passes layer into linear quant implementation. ### How was this patch tested? By performing offline inference --------- Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/quantization/quant_config.py | 120 ++++++++++------------- vllm_ascend/quantization/quantizer.py | 4 +- 2 files changed, 52 insertions(+), 72 deletions(-) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 7fb2622..3130142 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -30,9 +30,9 @@ from vllm.model_executor.layers.quantization import \ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.parameter import (BasevLLMParameter, - ChannelQuantScaleParameter, - ModelWeightParameter) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) from .quantizer import AscendQuantizer @@ -41,7 +41,11 @@ logger = init_logger(__name__) @register_quantization_config("ascend") class AscendQuantConfig(QuantizationConfig): - """Config class for Ascend""" + """Config class for Ascend + + This class is a general class that parse quantization configs + that are supported on ascend hardware. + """ def __init__(self, quant_config: Dict[str, Any]): self.quant_description = quant_config @@ -84,10 +88,10 @@ class AscendQuantConfig(QuantizationConfig): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): return UnquantizedLinearMethod() - return AscendLinearMethod(self) + return AscendLinearMethod(self, prefix) if isinstance(layer, Attention) and \ 'fa_quant_type' in self.quant_description.keys(): - return AscendQKVQuantAttentionMethod(self) + return AscendKVCacheMethod(self, prefix) return None def is_layer_skipped_ascend( @@ -127,13 +131,16 @@ class AscendQuantConfig(QuantizationConfig): class AscendLinearMethod(LinearMethodBase): """Linear method for Ascend quantization. + This class calls AscendQuantizer to search a specific quantization + implementations supported on ascend hardware for linear methods. + Args: quant_config: The Ascend quantization config. """ - def __init__(self, quant_config: AscendQuantConfig) -> None: + def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description) + quant_config.quant_description, prefix) self.quant_method = self.quantizer.build_linear_method() def create_weights( @@ -146,57 +153,40 @@ class AscendLinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - del output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") - weights = self.quant_method.create_weights(input_size_per_partition, + weight_dict = self.quant_method.get_weight(input_size_per_partition, output_size_per_partition, params_dtype) - - weight_name = self.quant_method.get_weight() - if weight_name in weights.keys(): + for weight_name, weight_param in weight_dict.items(): layer.register_parameter( weight_name, - ModelWeightParameter(data=weights[weight_name].transpose(0, 1), + ModelWeightParameter(data=weight_param, input_dim=1, output_dim=0, weight_loader=weight_loader)) - else: - raise ValueError( - f"{weight_name} is nor registered. Please check your linear quant method implementation." - ) - pertensor_names = self.quant_method.get_pertensor_param() - for pertensor_name in pertensor_names: - if pertensor_name in weights.keys(): - param = BasevLLMParameter(data=weights[pertensor_name], - weight_loader=weight_loader) - # disable warning - param.ignore_warning = True - layer.register_parameter(pertensor_name, param) - else: - raise ValueError( - f"{pertensor_name} is nor registered. Please check your linear quant method implementation." - ) + pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) + for pertensor_name, pertensor_param in pertensor_dict.items(): + param = PerTensorScaleParameter(data=pertensor_param, + weight_loader=weight_loader) + # disable warning + param.ignore_warning = True + layer.register_parameter(pertensor_name, param) - perchannel_names = self.quant_method.get_perchannel_param() - for perchannel_name in perchannel_names: - if perchannel_name in weights.keys(): - layer.register_parameter( - perchannel_name, - ChannelQuantScaleParameter(data=weights[perchannel_name], - output_dim=0, - weight_loader=weight_loader)) - else: - raise ValueError( - f"{perchannel_name} is nor registered. Please check your linear quant method implementation." - ) + perchannel_dict = self.quant_method.get_perchannel_param( + output_size_per_partition, params_dtype) + for perchannel_name, perchannel_param in perchannel_dict.items(): + layer.register_parameter( + perchannel_name, + ChannelQuantScaleParameter(data=perchannel_param, + output_dim=0, + weight_loader=weight_loader)) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if hasattr(self.quant_method, - 'transpose_weight') and self.quant_method.transpose_weight: - layer.weight.data = layer.weight.data.transpose(1, 0) + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) def apply( self, @@ -210,47 +200,37 @@ class AscendLinearMethod(LinearMethodBase): return self.quant_method.apply(layer, x, bias) -class AscendQKVQuantAttentionMethod(BaseKVCacheMethod): - """Linear method for Ascend quantization. +class AscendKVCacheMethod(BaseKVCacheMethod): + """KVCache method for Ascend quantization. + + This class calls AscendQuantizer to search a specific quantization + implementations supported on ascend hardware for kvcache methods. Args: quant_config: The Ascend quantization config. """ - def __init__(self, quant_config: AscendQuantConfig) -> None: + def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description) + quant_config.quant_description, prefix) self.quant_method = self.quantizer.build_attention_method() def create_weights(self, layer: torch.nn.Module) -> None: - # ascend attention quantization might include some extra weights - # and must be loaded by dummy modules - extra_module_names = self.quant_method.get_extra_module_names() - for name in extra_module_names: - setattr(layer, name, torch.nn.Module()) - - # During model initialization, the default dtype is set as the model - # weight and activation dtype. - dtype = torch.get_default_dtype() - weights = self.quant_method.create_weights(dtype, layer.num_heads, - layer.num_kv_heads) - - for name, weight in weights.items(): - module_name, weight_name = name.split('.') - module = getattr(layer, module_name) - module.register_parameter( - weight_name, torch.nn.Parameter(weight, requires_grad=False)) + # Different from linear method, there are no weight processing/slicing + # steps for attention in vllm. So the whole process of create weights + # is hidden into the specific quant method. + self.quant_method.create_weights(layer) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) def apply(self, layer: torch.nn.Module, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor, - value_cache: torch.Tensor, scale: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + kv_cache: List[torch.Tensor], scale: torch.Tensor, seq_lens_tensor_cpu: int, block_tables: torch.Tensor, isPrefill: bool, attn_metadata, output) -> torch.Tensor: - return self.quant_method.apply(layer, query, key, value, key_cache, - value_cache, scale, seq_lens_tensor_cpu, + return self.quant_method.apply(layer, query, key, value, kv_cache, + scale, seq_lens_tensor_cpu, block_tables, isPrefill, attn_metadata, output) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index f6cc450..b7c8fe9 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -25,7 +25,7 @@ class AscendQuantizer: """An interface to different quantization implementations for ascend hardwares.""" @classmethod - def get_quantizer(cls, quant_config: Dict[str, Any]): + def get_quantizer(cls, quant_config: Dict[str, Any], prefix: str): # TODO: Need a param to choose quantization algorithms. quantization_algorithm = '' @@ -39,7 +39,7 @@ class AscendQuantizer: raise NotImplementedError( "There is no available ascend quantizer.") - return MindIETurboQuantizer.get_quantizer(quant_config) + return MindIETurboQuantizer.get_quantizer(quant_config, prefix) def build_linear_method(self): raise NotImplementedError