diff --git a/tests/ut/quantization/test_modelslim_config.py b/tests/ut/quantization/test_modelslim_config.py index 745b4736..250c576a 100644 --- a/tests/ut/quantization/test_modelslim_config.py +++ b/tests/ut/quantization/test_modelslim_config.py @@ -253,6 +253,3 @@ class TestAscendModelSlimConfig(TestBase): self.assertIn("model.layers.0.weight", config.quant_description) self.assertEqual(config.quant_description["model.layers.0.weight"], "INT8") - - def test_get_scaled_act_names(self): - self.assertEqual(self.ascend_config.get_scaled_act_names(), []) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index c7c14394..a38ba6b4 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -49,6 +49,7 @@ from vllm_ascend.ops.layer_shard_linear import ( ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod +from vllm_ascend.quantization.utils import enable_fa_quant from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors from vllm_ascend.worker.npu_input_batch import NPUInputBatch @@ -730,10 +731,7 @@ class AscendMLAImpl(MLAAttentionImpl): self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) self.layer_name = kwargs.get("layer_name") - quant_config = self.vllm_config.quant_config - self.fa_quant_layer = ( - quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False - ) + self.fa_quant_layer = enable_fa_quant(self.vllm_config, self.layer_name) self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype self.layer_sharding_kwargs = [] for layer_name in get_ascend_config().layer_sharding or []: diff --git a/vllm_ascend/quantization/modelslim_config.py b/vllm_ascend/quantization/modelslim_config.py index b74242b2..2e8d6212 100644 --- a/vllm_ascend/quantization/modelslim_config.py +++ b/vllm_ascend/quantization/modelslim_config.py @@ -660,9 +660,6 @@ class AscendModelSlimConfig(QuantizationConfig): extra_quant_dict[new_k] = self.quant_description[k] self.quant_description.update(extra_quant_dict) - def get_scaled_act_names(self) -> list[str]: - return [] - def _add_kvcache_quant_metadata(self): fa_quant_type = self.quant_description.get("fa_quant_type", "") self.enable_fa_quant = fa_quant_type != "" diff --git a/vllm_ascend/quantization/utils.py b/vllm_ascend/quantization/utils.py index 0ec87437..7f83f3de 100644 --- a/vllm_ascend/quantization/utils.py +++ b/vllm_ascend/quantization/utils.py @@ -197,3 +197,12 @@ def maybe_auto_detect_quantization(vllm_config) -> None: from vllm.config import VllmConfig as _VllmConfig vllm_config.quant_config = _VllmConfig._get_quantization_config(model_config, vllm_config.load_config) + + +def enable_fa_quant(vllm_config, layer_name=None) -> bool: + if vllm_config.quant_config is not None and getattr(vllm_config.quant_config, "enable_fa_quant", False): + if layer_name is not None: + return vllm_config.quant_config.enabling_fa_quant(vllm_config, layer_name) + else: + return True + return False diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f414d4f9..36546b26 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -109,6 +109,7 @@ from vllm_ascend.eplb.utils import model_register from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort +from vllm_ascend.quantization.utils import enable_fa_quant from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer @@ -2763,7 +2764,7 @@ class NPUModelRunner(GPUModelRunner): k_dim, v_dim, ] - if self.is_kv_consumer and self.vllm_config.quant_config is not None: + if self.is_kv_consumer and enable_fa_quant(self.vllm_config): k_tensor_split_factor, v_tensor_split_factor = ( self.vllm_config.quant_config.get_kv_quant_split_factor(layer_name, kv_head_dim_list) ) @@ -2950,7 +2951,7 @@ class NPUModelRunner(GPUModelRunner): v_dim, ) k_cache_dtype = v_cache_dtype = current_kv_cache_spec.dtype - if self.is_kv_consumer and self.vllm_config.quant_config is not None: + if self.is_kv_consumer and enable_fa_quant(self.vllm_config): k_cache_dtype, v_cache_dtype = self.vllm_config.quant_config.get_kv_quant_dtype( layer_name, current_kv_cache_spec.dtype, self.model_config )