[BugFix] Fix bugs when using ascend quantization (#275)
### What this PR does / why we need it? It fixes following bugs: 1. When searching a specific linear quantization implementation from a tool (such as MindIE-Turbo), the mapping of packed linear is required to identify the corresponding quant type. 2. The exception is narrowed down to ImportError when importing MindIETurboQuantizer to better throw other errors. 3. The api of AscendKVCacheMethod.apply is aligned with that in AscendAttentionBackendImpl. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By performing offline inference:  --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -744,10 +744,19 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None
|
||||||
# Details of kv_cache arrangement in attention quantization
|
# Details of kv_cache arrangement in attention quantization
|
||||||
# are implemented by quant_method.
|
# are implemented by quant_method.
|
||||||
layer.quant_method.apply(layer, query, key, value, self.key_cache,
|
layer.quant_method.apply(
|
||||||
self.value_cache, self.scale,
|
layer,
|
||||||
self.seq_lens_tensor_cpu, block_tables,
|
query,
|
||||||
isPrefill, attn_metadata, output)
|
key,
|
||||||
|
value,
|
||||||
|
self.key_cache,
|
||||||
|
self.value_cache,
|
||||||
|
self.scale,
|
||||||
|
block_tables,
|
||||||
|
isPrefill,
|
||||||
|
attn_metadata,
|
||||||
|
output,
|
||||||
|
seq_lens_tensor_cpu=self.seq_lens_tensor_cpu)
|
||||||
else:
|
else:
|
||||||
if self.key_cache is not None:
|
if self.key_cache is not None:
|
||||||
torch_npu._npu_reshape_and_cache(key=key,
|
torch_npu._npu_reshape_and_cache(key=key,
|
||||||
|
|||||||
@@ -88,7 +88,8 @@ class AscendQuantConfig(QuantizationConfig):
|
|||||||
if self.is_layer_skipped_ascend(prefix,
|
if self.is_layer_skipped_ascend(prefix,
|
||||||
self.packed_modules_mapping):
|
self.packed_modules_mapping):
|
||||||
return UnquantizedLinearMethod()
|
return UnquantizedLinearMethod()
|
||||||
return AscendLinearMethod(self, prefix)
|
return AscendLinearMethod(self, prefix,
|
||||||
|
self.packed_modules_mapping)
|
||||||
if isinstance(layer, Attention) and \
|
if isinstance(layer, Attention) and \
|
||||||
'fa_quant_type' in self.quant_description.keys():
|
'fa_quant_type' in self.quant_description.keys():
|
||||||
return AscendKVCacheMethod(self, prefix)
|
return AscendKVCacheMethod(self, prefix)
|
||||||
@@ -138,9 +139,10 @@ class AscendLinearMethod(LinearMethodBase):
|
|||||||
quant_config: The Ascend quantization config.
|
quant_config: The Ascend quantization config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||||
|
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quantizer = AscendQuantizer.get_quantizer(
|
||||||
quant_config.quant_description, prefix)
|
quant_config.quant_description, prefix, packed_modules_mapping)
|
||||||
self.quant_method = self.quantizer.build_linear_method()
|
self.quant_method = self.quantizer.build_linear_method()
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
@@ -225,12 +227,29 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
|
|||||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||||
self.quant_method.process_weights_after_loading(layer)
|
self.quant_method.process_weights_after_loading(layer)
|
||||||
|
|
||||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
def apply(self,
|
||||||
key: torch.Tensor, value: torch.Tensor,
|
layer: torch.nn.Module,
|
||||||
kv_cache: List[torch.Tensor], scale: torch.Tensor,
|
query: torch.Tensor,
|
||||||
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
key: torch.Tensor,
|
||||||
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
value: torch.Tensor,
|
||||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
k_cache: List[torch.Tensor],
|
||||||
scale, seq_lens_tensor_cpu,
|
v_cache: List[torch.Tensor],
|
||||||
block_tables, isPrefill, attn_metadata,
|
scale: torch.Tensor,
|
||||||
output)
|
block_tables: torch.Tensor,
|
||||||
|
isPrefill: bool,
|
||||||
|
attn_metadata,
|
||||||
|
output,
|
||||||
|
seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor:
|
||||||
|
return self.quant_method.apply(layer,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
isPrefill,
|
||||||
|
attn_metadata.attn_mask,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
output,
|
||||||
|
seq_lens_tensor_cpu=seq_lens_tensor_cpu)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
|
CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
|
||||||
|
|
||||||
@@ -25,7 +25,11 @@ class AscendQuantizer:
|
|||||||
"""An interface to different quantization implementations for ascend hardwares."""
|
"""An interface to different quantization implementations for ascend hardwares."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_quantizer(cls, quant_config: Dict[str, Any], prefix: str):
|
def get_quantizer(cls,
|
||||||
|
quant_config: Dict[str, Any],
|
||||||
|
prefix: str,
|
||||||
|
packed_modules_mapping: Optional[Dict[str,
|
||||||
|
Any]] = dict()):
|
||||||
# TODO: Need a param to choose quantization algorithms.
|
# TODO: Need a param to choose quantization algorithms.
|
||||||
quantization_algorithm = ''
|
quantization_algorithm = ''
|
||||||
|
|
||||||
@@ -35,11 +39,12 @@ class AscendQuantizer:
|
|||||||
try:
|
try:
|
||||||
module = importlib.import_module("mindie_turbo")
|
module = importlib.import_module("mindie_turbo")
|
||||||
MindIETurboQuantizer = module.MindIETurboQuantizer
|
MindIETurboQuantizer = module.MindIETurboQuantizer
|
||||||
except Exception:
|
except ImportError:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"There is no available ascend quantizer.")
|
"There is no available ascend quantizer.")
|
||||||
|
|
||||||
return MindIETurboQuantizer.get_quantizer(quant_config, prefix)
|
return MindIETurboQuantizer.get_quantizer(quant_config, prefix,
|
||||||
|
packed_modules_mapping)
|
||||||
|
|
||||||
def build_linear_method(self):
|
def build_linear_method(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
Reference in New Issue
Block a user