From 7330416de3fd2f8c6b9b82fb1ad0adfb9c70d483 Mon Sep 17 00:00:00 2001 From: Angazenn <92204292+Angazenn@users.noreply.github.com> Date: Wed, 12 Mar 2025 11:33:21 +0800 Subject: [PATCH] [BugFix] Fix bugs when using ascend quantization (#275) ### What this PR does / why we need it? It fixes following bugs: 1. When searching a specific linear quantization implementation from a tool (such as MindIE-Turbo), the mapping of packed linear is required to identify correponding quant type. 2. The exception is narrowed down to ImportError when importing MindIETurboQuantizer to better throw other errors. 3. The api of AscendKVCacheMethod.apply is aligned with that in AscendAttentionBackendImpl. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By performing offline inference: ![image](https://github.com/user-attachments/assets/d63804cf-c060-451f-9cb0-d012e06b5333) --------- Signed-off-by: angazenn Co-authored-by: angazenn --- vllm_ascend/attention.py | 17 +++++++--- vllm_ascend/quantization/quant_config.py | 43 +++++++++++++++++------- vllm_ascend/quantization/quantizer.py | 13 ++++--- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py index 2aa915c..5771a11 100644 --- a/vllm_ascend/attention.py +++ b/vllm_ascend/attention.py @@ -744,10 +744,19 @@ class AscendAttentionBackendImpl(AttentionImpl): block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None # Details of kv_cache arrangement in attention quantization # are implemented by quant_method. - layer.quant_method.apply(layer, query, key, value, self.key_cache, - self.value_cache, self.scale, - self.seq_lens_tensor_cpu, block_tables, - isPrefill, attn_metadata, output) + layer.quant_method.apply( + layer, + query, + key, + value, + self.key_cache, + self.value_cache, + self.scale, + block_tables, + isPrefill, + attn_metadata, + output, + seq_lens_tensor_cpu=self.seq_lens_tensor_cpu) else: if self.key_cache is not None: torch_npu._npu_reshape_and_cache(key=key, diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 3130142..51f201e 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -88,7 +88,8 @@ class AscendQuantConfig(QuantizationConfig): if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping): return UnquantizedLinearMethod() - return AscendLinearMethod(self, prefix) + return AscendLinearMethod(self, prefix, + self.packed_modules_mapping) if isinstance(layer, Attention) and \ 'fa_quant_type' in self.quant_description.keys(): return AscendKVCacheMethod(self, prefix) @@ -138,9 +139,10 @@ class AscendLinearMethod(LinearMethodBase): quant_config: The Ascend quantization config. """ - def __init__(self, quant_config: AscendQuantConfig, prefix: str) -> None: + def __init__(self, quant_config: AscendQuantConfig, prefix: str, + packed_modules_mapping: Dict[str, Any]) -> None: self.quantizer = AscendQuantizer.get_quantizer( - quant_config.quant_description, prefix) + quant_config.quant_description, prefix, packed_modules_mapping) self.quant_method = self.quantizer.build_linear_method() def create_weights( @@ -225,12 +227,29 @@ class AscendKVCacheMethod(BaseKVCacheMethod): if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) - def apply(self, layer: torch.nn.Module, query: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - kv_cache: List[torch.Tensor], scale: torch.Tensor, - seq_lens_tensor_cpu: int, block_tables: torch.Tensor, - isPrefill: bool, attn_metadata, output) -> torch.Tensor: - return self.quant_method.apply(layer, query, key, value, kv_cache, - scale, seq_lens_tensor_cpu, - block_tables, isPrefill, attn_metadata, - output) + def apply(self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + k_cache: List[torch.Tensor], + v_cache: List[torch.Tensor], + scale: torch.Tensor, + block_tables: torch.Tensor, + isPrefill: bool, + attn_metadata, + output, + seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor: + return self.quant_method.apply(layer, + query, + key, + value, + k_cache, + v_cache, + scale, + block_tables, + isPrefill, + attn_metadata.attn_mask, + attn_metadata.slot_mapping, + output, + seq_lens_tensor_cpu=seq_lens_tensor_cpu) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index b7c8fe9..eee5159 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -16,7 +16,7 @@ # import importlib -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional CUSTOMIZED_QUANTIZER_TYPE: List[str] = [] @@ -25,7 +25,11 @@ class AscendQuantizer: """An interface to different quantization implementations for ascend hardwares.""" @classmethod - def get_quantizer(cls, quant_config: Dict[str, Any], prefix: str): + def get_quantizer(cls, + quant_config: Dict[str, Any], + prefix: str, + packed_modules_mapping: Optional[Dict[str, + Any]] = dict()): # TODO: Need a param to choose quantization algorithms. quantization_algorithm = '' @@ -35,11 +39,12 @@ class AscendQuantizer: try: module = importlib.import_module("mindie_turbo") MindIETurboQuantizer = module.MindIETurboQuantizer - except Exception: + except ImportError: raise NotImplementedError( "There is no available ascend quantizer.") - return MindIETurboQuantizer.get_quantizer(quant_config, prefix) + return MindIETurboQuantizer.get_quantizer(quant_config, prefix, + packed_modules_mapping) def build_linear_method(self): raise NotImplementedError