[BugFix] Fix bugs when using ascend quantization (#275)
### What this PR does / why we need it? It fixes the following bugs: 1. When looking up a specific linear quantization implementation from a tool (such as MindIE-Turbo), the mapping of packed linear modules is required to identify the corresponding quant type. 2. The exception caught when importing MindIETurboQuantizer is narrowed down to ImportError, so that other errors propagate instead of being masked. 3. The API of AscendKVCacheMethod.apply is aligned with that in AscendAttentionBackendImpl. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By performing offline inference:  --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -88,7 +88,8 @@ class AscendQuantConfig(QuantizationConfig):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
return AscendLinearMethod(self, prefix)
|
||||
return AscendLinearMethod(self, prefix,
|
||||
self.packed_modules_mapping)
|
||||
if isinstance(layer, Attention) and \
|
||||
'fa_quant_type' in self.quant_description.keys():
|
||||
return AscendKVCacheMethod(self, prefix)
|
||||
@@ -138,9 +139,10 @@ class AscendLinearMethod(LinearMethodBase):
|
||||
quant_config: The Ascend quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None:
|
||||
def __init__(self, quant_config: AscendQuantConfig, prefix: str,
|
||||
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||
self.quantizer = AscendQuantizer.get_quantizer(
|
||||
quant_config.quant_description, prefix)
|
||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
||||
self.quant_method = self.quantizer.build_linear_method()
|
||||
|
||||
def create_weights(
|
||||
@@ -225,12 +227,29 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
|
||||
def apply(self, layer: torch.nn.Module, query: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
kv_cache: List[torch.Tensor], scale: torch.Tensor,
|
||||
seq_lens_tensor_cpu: int, block_tables: torch.Tensor,
|
||||
isPrefill: bool, attn_metadata, output) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, query, key, value, kv_cache,
|
||||
scale, seq_lens_tensor_cpu,
|
||||
block_tables, isPrefill, attn_metadata,
|
||||
output)
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
scale: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
isPrefill: bool,
|
||||
attn_metadata,
|
||||
output,
|
||||
seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
k_cache,
|
||||
v_cache,
|
||||
scale,
|
||||
block_tables,
|
||||
isPrefill,
|
||||
attn_metadata.attn_mask,
|
||||
attn_metadata.slot_mapping,
|
||||
output,
|
||||
seq_lens_tensor_cpu=seq_lens_tensor_cpu)
|
||||
|
||||
Reference in New Issue
Block a user