diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 5414ee58..03834212 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -22,7 +22,7 @@ from vllm_ascend.attention.utils import (
     AscendCommonAttentionMetadata,
     ascend_chunked_prefill_workspace_size,
     enable_cp,
-    enabling_malpo,
+    enabling_mlapo,
     maybe_save_kv_layer_to_connector,
     split_decodes_and_prefills,
     trans_rope_weight,
@@ -710,7 +710,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.ring_mla_mask_size = 512
         self.speculative_config = self.vllm_config.speculative_config
 
-        self.enable_mlapo = enabling_malpo(self.vllm_config)
+        self.enable_mlapo = enabling_mlapo(self.vllm_config)
         self.is_kv_producer = (
             self.vllm_config.kv_transfer_config is not None
             and self.vllm_config.kv_transfer_config.is_kv_producer
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 8414fc5d..8eb91ec5 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -305,6 +305,10 @@ def transdata(nd_mat, block_size: tuple = (16, 16)):
     return nz_mat
 
 
-def enabling_malpo(vllm_config: VllmConfig) -> bool:
-    is_decode_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_consumer
+def enabling_mlapo(vllm_config: VllmConfig) -> bool:
+    is_decode_instance = (
+        vllm_config.kv_transfer_config is not None
+        and vllm_config.kv_transfer_config.is_kv_consumer
+        and not vllm_config.kv_transfer_config.is_kv_producer
+    )
     return bool(envs.VLLM_ASCEND_ENABLE_MLAPO and is_decode_instance)