diff --git a/vllm_ascend/ops/mla.py b/vllm_ascend/ops/mla.py
index 64d5d36a..6f02cecd 100644
--- a/vllm_ascend/ops/mla.py
+++ b/vllm_ascend/ops/mla.py
@@ -126,16 +126,17 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
             o_proj=mla_modules.o_proj,
         )
 
-        original_process_weights = self.mla_attn.process_weights_after_loading
+        if not vllm_version_is("v0.15.0"):
+            original_process_weights = self.mla_attn.process_weights_after_loading
 
-        def wrapped_process_weights(act_dtype: torch.dtype):
-            from vllm_ascend.attention.sfa_v1 import AscendSFAImpl
+            def wrapped_process_weights(act_dtype: torch.dtype):
+                from vllm_ascend.attention.sfa_v1 import AscendSFAImpl
 
-            if not isinstance(self.mla_attn.impl, AscendSFAImpl):
-                original_process_weights(act_dtype)
-            self.mla_attn.impl.process_weights_after_loading(act_dtype)
+                if not isinstance(self.mla_attn.impl, AscendSFAImpl):
+                    original_process_weights(act_dtype)
+                self.mla_attn.impl.process_weights_after_loading(act_dtype)
 
-        self.mla_attn.process_weights_after_loading = wrapped_process_weights
+            self.mla_attn.process_weights_after_loading = wrapped_process_weights
 
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
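For context, the change above guards the hook replacement behind a version check and, for SFA-style attention impls, skips the layer-level weight pass in favor of the impl's own. Below is a minimal runnable sketch of the same wrapping pattern; `FakeAttn` and `SFALikeImpl` are hypothetical stand-ins for `self.mla_attn` and `AscendSFAImpl`, used only to illustrate the control flow, not the real vllm-ascend classes.

```python
# Minimal sketch of the bound-method wrapping pattern from the diff.
# FakeAttn / SFALikeImpl are hypothetical stand-ins, not real vLLM classes.
import torch


class SFALikeImpl:
    """Stand-in for AscendSFAImpl: handles its own weight processing."""

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        print(f"impl-specific weight processing ({act_dtype})")


class FakeAttn:
    """Stand-in for the mla_attn layer whose hook gets wrapped."""

    def __init__(self, impl) -> None:
        self.impl = impl

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        print(f"original layer-level weight processing ({act_dtype})")


attn = FakeAttn(SFALikeImpl())
# Capture the original bound method before replacing it on the instance.
original_process_weights = attn.process_weights_after_loading


def wrapped_process_weights(act_dtype: torch.dtype) -> None:
    # Skip the layer-level pass for SFA-style impls, then always
    # delegate to the impl's own post-loading hook.
    if not isinstance(attn.impl, SFALikeImpl):
        original_process_weights(act_dtype)
    attn.impl.process_weights_after_loading(act_dtype)


attn.process_weights_after_loading = wrapped_process_weights
attn.process_weights_after_loading(torch.bfloat16)
# -> prints only "impl-specific weight processing (torch.bfloat16)"
```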