[Feature] Modify description and api for ascend quantization (#243)
### What this PR does / why we need it? 1. It adds more description for classes in quant_config.py 2. It renames AscendQKVQuantAttentionMethod to AscendKVCacheMethod to align with vLLM naming style. 3. It modifies the process when AscendLinearMethod or AscendKVCacheMethod calls create_weights. ### Does this PR introduce _any_ user-facing change? Yes. When creating weights, now AscendLinearMethod uses get_weight, get_pertensor_param and get_perchannel_param api from linear quant implementation, while AscendKVCacheMethod passes layer into linear quant implementation. ### How was this patch tested? By performing offline inference --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -30,9 +30,9 @@ from vllm.model_executor.layers.quantization import \
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||||
ChannelQuantScaleParameter,
|
ModelWeightParameter,
|
||||||
ModelWeightParameter)
|
PerTensorScaleParameter)
|
||||||
|
|
||||||
from .quantizer import AscendQuantizer
|
from .quantizer import AscendQuantizer
|
||||||
|
|
||||||
@@ -41,7 +41,11 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
@register_quantization_config("ascend")
|
@register_quantization_config("ascend")
|
||||||
class AscendQuantConfig(QuantizationConfig):
|
class AscendQuantConfig(QuantizationConfig):
|
||||||
"""Config class for Ascend"""
|
"""Config class for Ascend
|
||||||
|
|
||||||
|
This class is a general class that parse quantization configs
|
||||||
|
that are supported on ascend hardware.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: Dict[str, Any]):
|
def __init__(self, quant_config: Dict[str, Any]):
|
||||||
self.quant_description = quant_config
|
self.quant_description = quant_config
|
||||||
@@ -84,10 +88,10 @@ class AscendQuantConfig(QuantizationConfig):
|
|||||||
if self.is_layer_skipped_ascend(prefix,
|
if self.is_layer_skipped_ascend(prefix,
|
||||||
self.packed_modules_mapping):
|
self.packed_modules_mapping):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
return AscendLinearMethod(self)
|
return AscendLinearMethod(self, prefix)
|
||||||
if isinstance(layer, Attention) and \
|
if isinstance(layer, Attention) and \
|
||||||
'fa_quant_type' in self.quant_description.keys():
|
'fa_quant_type' in self.quant_description.keys():
|
||||||
return AscendQKVQuantAttentionMethod(self)
|
return AscendKVCacheMethod(self, prefix)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def is_layer_skipped_ascend(
|
def is_layer_skipped_ascend(
|
||||||
@@ -127,13 +131,16 @@ class AscendQuantConfig(QuantizationConfig):
|
|||||||
class AscendLinearMethod(LinearMethodBase):
|
class AscendLinearMethod(LinearMethodBase):
|
||||||
"""Linear method for Ascend quantization.
|
"""Linear method for Ascend quantization.
|
||||||
|
|
||||||
|
This class calls AscendQuantizer to search a specific quantization
|
||||||
|
implementations supported on ascend hardware for linear methods.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig) -> None:
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quantizer = AscendQuantizer.get_quantizer(
|
||||||
quant_config.quant_description)
|
quant_config.quant_description, prefix)
|
||||||
self.quant_method = self.quantizer.build_linear_method()
|
self.quant_method = self.quantizer.build_linear_method()
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -146,57 +153,40 @@ class AscendLinearMethod(LinearMethodBase):
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
) -> None:
|
) -> None:
|
||||||
del output_size
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
|
||||||
weights = self.quant_method.create_weights(input_size_per_partition,
|
weight_dict = self.quant_method.get_weight(input_size_per_partition,
|
||||||
output_size_per_partition,
|
output_size_per_partition,
|
||||||
params_dtype)
|
params_dtype)
|
||||||
|
for weight_name, weight_param in weight_dict.items():
|
||||||
weight_name = self.quant_method.get_weight()
|
|
||||||
if weight_name in weights.keys():
|
|
||||||
layer.register_parameter(
|
layer.register_parameter(
|
||||||
weight_name,
|
weight_name,
|
||||||
ModelWeightParameter(data=weights[weight_name].transpose(0, 1),
|
ModelWeightParameter(data=weight_param,
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
output_dim=0,
|
output_dim=0,
|
||||||
weight_loader=weight_loader))
|
weight_loader=weight_loader))
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"{weight_name} is nor registered. Please check your linear quant method implementation."
|
|
||||||
)
|
|
||||||
|
|
||||||
pertensor_names = self.quant_method.get_pertensor_param()
|
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
|
||||||
for pertensor_name in pertensor_names:
|
for pertensor_name, pertensor_param in pertensor_dict.items():
|
||||||
if pertensor_name in weights.keys():
|
param = PerTensorScaleParameter(data=pertensor_param,
|
||||||
param = BasevLLMParameter(data=weights[pertensor_name],
|
weight_loader=weight_loader)
|
||||||
weight_loader=weight_loader)
|
# disable warning
|
||||||
# disable warning
|
param.ignore_warning = True
|
||||||
param.ignore_warning = True
|
layer.register_parameter(pertensor_name, param)
|
||||||
layer.register_parameter(pertensor_name, param)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"{pertensor_name} is nor registered. Please check your linear quant method implementation."
|
|
||||||
)
|
|
||||||
|
|
||||||
perchannel_names = self.quant_method.get_perchannel_param()
|
perchannel_dict = self.quant_method.get_perchannel_param(
|
||||||
for perchannel_name in perchannel_names:
|
output_size_per_partition, params_dtype)
|
||||||
if perchannel_name in weights.keys():
|
for perchannel_name, perchannel_param in perchannel_dict.items():
|
||||||
layer.register_parameter(
|
layer.register_parameter(
|
||||||
perchannel_name,
|
perchannel_name,
|
||||||
ChannelQuantScaleParameter(data=weights[perchannel_name],
|
ChannelQuantScaleParameter(data=perchannel_param,
|
||||||
output_dim=0,
|
output_dim=0,
|
||||||
weight_loader=weight_loader))
|
weight_loader=weight_loader))
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"{perchannel_name} is nor registered. Please check your linear quant method implementation."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if hasattr(self.quant_method,
|
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||||
'transpose_weight') and self.quant_method.transpose_weight:
|
self.quant_method.process_weights_after_loading(layer)
|
||||||
layer.weight.data = layer.weight.data.transpose(1, 0)
|
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -210,47 +200,37 @@ class AscendLinearMethod(LinearMethodBase):
|
|||||||
return self.quant_method.apply(layer, x, bias)
|
return self.quant_method.apply(layer, x, bias)
|
||||||
|
|
||||||
|
|
||||||
class AscendQKVQuantAttentionMethod(BaseKVCacheMethod):
|
class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||||
"""Linear method for Ascend quantization.
|
"""KVCache method for Ascend quantization.
|
||||||
|
|
||||||
|
This class calls AscendQuantizer to search a specific quantization
|
||||||
|
implementations supported on ascend hardware for kvcache methods.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig) -> None:
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quantizer = AscendQuantizer.get_quantizer(
|
||||||
quant_config.quant_description)
|
quant_config.quant_description, prefix)
|
||||||
self.quant_method = self.quantizer.build_attention_method()
|
self.quant_method = self.quantizer.build_attention_method()
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module) -> None:
|
def create_weights(self, layer: torch.nn.Module) -> None:
|
||||||
# ascend attention quantization might include some extra weights
|
# Different from linear method, there are no weight processing/slicing
|
||||||
# and must be loaded by dummy modules
|
# steps for attention in vllm. So the whole process of create weights
|
||||||
extra_module_names = self.quant_method.get_extra_module_names()
|
# is hidden into the specific quant method.
|
||||||
for name in extra_module_names:
|
self.quant_method.create_weights(layer)
|
||||||
setattr(layer, name, torch.nn.Module())
|
|
||||||
|
|
||||||
# During model initialization, the default dtype is set as the model
|
|
||||||
# weight and activation dtype.
|
|
||||||
dtype = torch.get_default_dtype()
|
|
||||||
weights = self.quant_method.create_weights(dtype, layer.num_heads,
|
|
||||||
layer.num_kv_heads)
|
|
||||||
|
|
||||||
for name, weight in weights.items():
|
|
||||||
module_name, weight_name = name.split('.')
|
|
||||||
module = getattr(layer, module_name)
|
|
||||||
module.register_parameter(
|
|
||||||
weight_name, torch.nn.Parameter(weight, requires_grad=False))
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||||
self.quant_method.process_weights_after_loading(layer)
|
self.quant_method.process_weights_after_loading(layer)
|
||||||
|
|
||||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||||
key: torch.Tensor, value: torch.Tensor, key_cache: torch.Tensor,
|
key: torch.Tensor, value: torch.Tensor,
|
||||||
value_cache: torch.Tensor, scale: torch.Tensor,
|
kv_cache: List[torch.Tensor], scale: torch.Tensor,
|
||||||
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
||||||
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
||||||
return self.quant_method.apply(layer, query, key, value, key_cache,
|
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
||||||
value_cache, scale, seq_lens_tensor_cpu,
|
scale, seq_lens_tensor_cpu,
|
||||||
block_tables, isPrefill, attn_metadata,
|
block_tables, isPrefill, attn_metadata,
|
||||||
output)
|
output)
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class AscendQuantizer:
|
|||||||
"""An interface to different quantization implementations for ascend hardwares."""
|
"""An interface to different quantization implementations for ascend hardwares."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_quantizer(cls, quant_config: Dict[str, Any]):
|
def get_quantizer(cls, quant_config: Dict[str, Any], prefix: str):
|
||||||
# TODO: Need a param to choose quantization algorithms.
|
# TODO: Need a param to choose quantization algorithms.
|
||||||
quantization_algorithm = ''
|
quantization_algorithm = ''
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ class AscendQuantizer:
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"There is no available ascend quantizer.")
|
"There is no available ascend quantizer.")
|
||||||
|
|
||||||
return MindIETurboQuantizer.get_quantizer(quant_config)
|
return MindIETurboQuantizer.get_quantizer(quant_config, prefix)
|
||||||
|
|
||||||
def build_linear_method(self):
|
def build_linear_method(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user