From a2daacbd7157a315f1dd07e9a0b37f8dda1ea9d2 Mon Sep 17 00:00:00 2001
From: Chen Chen <0109chenchen@gmail.com>
Date: Mon, 5 Jan 2026 21:29:45 +0800
Subject: [PATCH] [perf] Fix MLAPO weight disposal for KV-consumer MLA in
 PD-mix deploy... (#5192)

### What this PR does / why we need it?
- Problem: In MLA+MLAPO, KV-consumer deployments keep the fused_qkv_a_proj/q_proj weights and quant params even though MLAPO only reads the prepacked buffers, which inflates the memory footprint on decode nodes.
- Fix: Conditionally drop those tensors only when `kv_transfer_config.is_kv_consumer` to reclaim memory (consistent with the SFA behavior, #4774). A standalone sketch of the pattern follows below.
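For clarity, here is a minimal standalone sketch of the disposal pattern this PR adds. The helper name and the config stubs are illustrative, not part of the patch; the attribute names mirror the diff:

```python
# Illustrative sketch only -- not part of the patch. The attribute names
# (weight / deq_scale / quant_bias) mirror the diff; the helper name and
# the config stand-ins are hypothetical. Requires an Ascend environment.
import torch
import torch_npu  # noqa: F401  -- registers the torch.npu backend


def maybe_free_mlapo_source_weights(impl, vllm_config, ascend_config):
    """Release fused_qkv_a_proj / q_proj source tensors on KV-consumer
    (decode-only) nodes, where MLAPO reads only the prepacked buffers."""
    kv_cfg = vllm_config.kv_transfer_config
    if kv_cfg is None or not kv_cfg.is_kv_consumer:
        return  # prefill / mixed nodes still run the original projections
    if not ascend_config.recompute_scheduler_enable:
        return  # the patch gates disposal on the recompute scheduler
    for proj in (impl.fused_qkv_a_proj, impl.q_proj):
        proj.weight = None      # weight tensor, superseded by prepacked buffers
        proj.deq_scale = None   # dequant scale (quantized checkpoints)
        proj.quant_bias = None  # quant bias (quantized checkpoints)
    # Dropping the references only orphans the tensors; empty_cache() hands
    # the freed blocks back to the NPU caching allocator.
    torch.npu.empty_cache()
```

Assigning `None` rather than `del`-ing the attributes keeps the module interface intact; the storage is reclaimed once the last reference is gone, and `empty_cache()` returns the freed blocks to the NPU allocator.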
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
---
 vllm_ascend/attention/mla_v1.py | 15 +++++++++++++++
 vllm_ascend/attention/sfa_v1.py | 11 ++++++++---
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 6d6b0db5..aa4ed077 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -912,6 +912,21 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
 
+        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
+        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
+        # referenced, so drop them to save memory.
+        ascend_config = get_ascend_config()
+        if self.vllm_config.kv_transfer_config is not None and \
+            self.vllm_config.kv_transfer_config.is_kv_consumer and \
+            ascend_config.recompute_scheduler_enable:
+            self.fused_qkv_a_proj.weight = None
+            self.fused_qkv_a_proj.deq_scale = None
+            self.fused_qkv_a_proj.quant_bias = None
+            self.q_proj.weight = None
+            self.q_proj.deq_scale = None
+            self.q_proj.quant_bias = None
+            torch.npu.empty_cache()
+
     def get_context_seq_len_npu(self, index: int,
                                 attn_metadata: AscendMLAMetadata):
         prefill_metadata = attn_metadata.prefill

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 6588686e..12ac00bc 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -371,7 +371,7 @@ class AscendSFAImpl(MLAAttentionImpl):
         if self.enable_sfa_cp:
             self.local_num_heads = self.num_heads * self.tp_size
 
-            #TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
+            # TODO: Temporarily adapt sfa-cp, remove after adapting near PCP. --clrs97
             self._replace_linear_class_for_sfa_cp()
             from vllm_ascend.distributed.parallel_state import \
                 get_shared_weight_group
@@ -537,7 +537,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             cache_mode=cache_mode,
             is_output_kv=True,
         )
-        #TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
+        # TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
         k_pe = get_tp_group().all_gather(k_pe, 0)
         k_nope = get_tp_group().all_gather(k_nope, 0)
 
@@ -659,8 +659,13 @@ class AscendSFAImpl(MLAAttentionImpl):
         self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
         self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
 
+        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
+        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
+        # referenced, so drop them to save memory.
+        ascend_config = get_ascend_config()
         if self.vllm_config.kv_transfer_config is not None and \
-            self.vllm_config.kv_transfer_config.is_kv_consumer:
+            self.vllm_config.kv_transfer_config.is_kv_consumer and \
+            ascend_config.recompute_scheduler_enable:
             self.fused_qkv_a_proj.weight = None
             self.fused_qkv_a_proj.deq_scale = None
             self.fused_qkv_a_proj.quant_bias = None