From 432b861cae054bbf424dcc498ab6abed548549d0 Mon Sep 17 00:00:00 2001
From: ZYang6263 <50876451+ZYang6263@users.noreply.github.com>
Date: Mon, 8 Dec 2025 23:17:45 +0800
Subject: [PATCH] Fix incorrect MLAPO weight release in PD mixed scenarios.
 (#4774)

### What this PR does / why we need it?
Fix incorrect MLAPO weight release in PD mixed scenarios. Previously the fused
`fused_qkv_a_proj` weight was dropped unconditionally during MLAPO weight
processing, and the dequant scales and quant biases were released whenever
`kv_transfer_config` was set. In PD mixed deployments the same node may still
serve prefill and therefore still needs these tensors, so the release is now
gated on `kv_transfer_config.is_kv_consumer`, and the `fused_qkv_a_proj` and
`q_proj` weights are released together with their scales and biases only on
KV-consumer nodes.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: ZYang6263
Co-authored-by: wangxiyuan
---
 vllm_ascend/attention/sfa_v1.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index b74efc8d..cc443f55 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -470,7 +470,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.fused_qkv_a_proj is None or not isinstance(
                 quant_method, AscendW8A8LinearMethod):
             reasons.append(
-                "Currently mlapo only supports W8A8 quantization in MLA scenario."
+                "Currently mlapo only supports W8A8 quantization in SFA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
         if self.enable_sfa_cp:
@@ -597,8 +597,6 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
             ..., :self.q_lora_rank].contiguous()
-        self.fused_qkv_a_proj.weight = None
-
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
         kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
         kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
@@ -673,9 +671,12 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
 
-        if self.vllm_config.kv_transfer_config is not None:
+        if self.vllm_config.kv_transfer_config is not None and \
+                self.vllm_config.kv_transfer_config.is_kv_consumer:
+            self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None
+            self.q_proj.weight = None
             self.q_proj.deq_scale = None
             self.q_proj.quant_bias = None
             torch.npu.empty_cache()
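
### Illustrative sketch of the guarded release

For clarity, a minimal, self-contained sketch of the pattern the patch switches to: the original quantized projection tensors are dropped only on KV-consumer (decode) nodes, so nodes that also run prefill in PD mixed deployments keep them. The names `KVTransferConfig`, `Projection`, and `release_mlapo_source_weights` are invented for illustration and are not the vllm_ascend API.

```python
# Hypothetical sketch of the conditional weight release; not the actual
# vllm_ascend/attention/sfa_v1.py code, only the same guarding pattern.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVTransferConfig:
    # In a disaggregated prefill/decode setup, decode nodes are KV consumers.
    is_kv_consumer: bool = False


@dataclass
class Projection:
    # Stand-ins for the quantized tensors held by a W8A8 linear layer.
    weight: Optional[object] = "weight tensor"
    deq_scale: Optional[object] = "dequant scale"
    quant_bias: Optional[object] = "quant bias"


def release_mlapo_source_weights(fused_qkv_a_proj: Projection,
                                 q_proj: Projection,
                                 kv_transfer_config: Optional[KVTransferConfig]) -> None:
    """Drop the source projection tensors only on KV-consumer (decode) nodes.

    PD mixed nodes may still serve prefill, so they must keep the tensors.
    """
    if kv_transfer_config is not None and kv_transfer_config.is_kv_consumer:
        fused_qkv_a_proj.weight = None
        fused_qkv_a_proj.deq_scale = None
        fused_qkv_a_proj.quant_bias = None
        q_proj.weight = None
        q_proj.deq_scale = None
        q_proj.quant_bias = None


if __name__ == "__main__":
    cfg = KVTransferConfig(is_kv_consumer=False)  # PD mixed / prefill-capable node
    fused, q = Projection(), Projection()
    release_mlapo_source_weights(fused, q, cfg)
    assert fused.weight is not None  # weights are kept, unlike before the fix
```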